modeling_tf_led.py 120 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663
  1. # coding=utf-8
  2. # Copyright 2021 Iz Beltagy, Matthew E. Peters, Arman Cohan and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """TF 2.0 LED model."""
  16. from __future__ import annotations
  17. import random
  18. from dataclasses import dataclass
  19. from typing import List, Optional, Tuple, Union
  20. import numpy as np
  21. import tensorflow as tf
  22. from ...activations_tf import get_tf_activation
  23. from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions
  24. # Public API
  25. from ...modeling_tf_utils import (
  26. TFModelInputType,
  27. TFPreTrainedModel,
  28. get_initializer,
  29. keras,
  30. keras_serializable,
  31. unpack_inputs,
  32. )
  33. from ...tf_utils import check_embeddings_within_bounds, shape_list, stable_softmax
  34. from ...utils import (
  35. ModelOutput,
  36. add_code_sample_docstrings,
  37. add_start_docstrings,
  38. add_start_docstrings_to_model_forward,
  39. logging,
  40. replace_return_docstrings,
  41. )
  42. from .configuration_led import LEDConfig
# Module-level logger for this file
logger = logging.get_logger(__name__)
# Default checkpoint / config names used by the doc-sample decorators
_CHECKPOINT_FOR_DOC = "allenai/led-base-16384"
_CONFIG_FOR_DOC = "LEDConfig"
# Additive mask value for disallowed attention positions (acts as -inf under softmax)
LARGE_NEGATIVE = -1e8
  47. # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
  48. def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
  49. pad_token_id = tf.cast(pad_token_id, input_ids.dtype)
  50. decoder_start_token_id = tf.cast(decoder_start_token_id, input_ids.dtype)
  51. start_tokens = tf.fill(
  52. (shape_list(input_ids)[0], 1), tf.convert_to_tensor(decoder_start_token_id, input_ids.dtype)
  53. )
  54. shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
  55. # replace possible -100 values in labels by `pad_token_id`
  56. shifted_input_ids = tf.where(
  57. shifted_input_ids == -100,
  58. tf.fill(shape_list(shifted_input_ids), tf.convert_to_tensor(pad_token_id, input_ids.dtype)),
  59. shifted_input_ids,
  60. )
  61. # "Verify that `labels` has only positive values and -100"
  62. assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.constant(0, dtype=input_ids.dtype))
  63. # Make sure the assertion op is called by wrapping the result in an identity no-op
  64. with tf.control_dependencies([assert_gte0]):
  65. shifted_input_ids = tf.identity(shifted_input_ids)
  66. return shifted_input_ids
  67. # Copied from transformers.models.bart.modeling_tf_bart._make_causal_mask
  68. def _make_causal_mask(input_ids_shape: tf.TensorShape, past_key_values_length: int = 0):
  69. """
  70. Make causal mask used for bi-directional self-attention.
  71. """
  72. bsz = input_ids_shape[0]
  73. tgt_len = input_ids_shape[1]
  74. mask = tf.ones((tgt_len, tgt_len)) * LARGE_NEGATIVE
  75. mask_cond = tf.range(shape_list(mask)[-1])
  76. mask = tf.where(mask_cond < tf.reshape(mask_cond + 1, (shape_list(mask)[-1], 1)), 0.0, mask)
  77. if past_key_values_length > 0:
  78. mask = tf.concat([tf.zeros((tgt_len, past_key_values_length)), mask], axis=-1)
  79. return tf.tile(mask[None, None, :, :], (bsz, 1, 1, 1))
  80. # Copied from transformers.models.bart.modeling_tf_bart._expand_mask
  81. def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
  82. """
  83. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
  84. """
  85. src_len = shape_list(mask)[1]
  86. tgt_len = tgt_len if tgt_len is not None else src_len
  87. one_cst = tf.constant(1.0)
  88. mask = tf.cast(mask, dtype=one_cst.dtype)
  89. expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
  90. return (one_cst - expanded_mask) * LARGE_NEGATIVE
  91. class TFLEDLearnedPositionalEmbedding(keras.layers.Embedding):
  92. """
  93. This module learns positional embeddings up to a fixed maximum size.
  94. """
  95. def __init__(self, num_embeddings: int, embedding_dim: int, **kwargs):
  96. super().__init__(num_embeddings, embedding_dim, **kwargs)
  97. def call(self, input_shape: tf.TensorShape, past_key_values_length: int = 0):
  98. """Input is expected to be of size [bsz x seqlen]."""
  99. seq_len = input_shape[1]
  100. position_ids = tf.range(seq_len, delta=1, name="range")
  101. position_ids += past_key_values_length
  102. return super().call(tf.cast(position_ids, dtype=tf.int32))
  103. # Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerSelfAttention with TFLongformer->TFLEDEncoder
  104. class TFLEDEncoderSelfAttention(keras.layers.Layer):
  105. def __init__(self, config, layer_id, **kwargs):
  106. super().__init__(**kwargs)
  107. self.config = config
  108. if config.hidden_size % config.num_attention_heads != 0:
  109. raise ValueError(
  110. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  111. f"heads ({config.num_attention_heads}"
  112. )
  113. self.num_heads = config.num_attention_heads
  114. self.head_dim = int(config.hidden_size / config.num_attention_heads)
  115. self.embed_dim = config.hidden_size
  116. self.query = keras.layers.Dense(
  117. self.embed_dim,
  118. kernel_initializer=get_initializer(config.initializer_range),
  119. name="query",
  120. )
  121. self.key = keras.layers.Dense(
  122. self.embed_dim,
  123. kernel_initializer=get_initializer(config.initializer_range),
  124. name="key",
  125. )
  126. self.value = keras.layers.Dense(
  127. self.embed_dim,
  128. kernel_initializer=get_initializer(config.initializer_range),
  129. name="value",
  130. )
  131. # separate projection layers for tokens with global attention
  132. self.query_global = keras.layers.Dense(
  133. self.embed_dim,
  134. kernel_initializer=get_initializer(config.initializer_range),
  135. name="query_global",
  136. )
  137. self.key_global = keras.layers.Dense(
  138. self.embed_dim,
  139. kernel_initializer=get_initializer(config.initializer_range),
  140. name="key_global",
  141. )
  142. self.value_global = keras.layers.Dense(
  143. self.embed_dim,
  144. kernel_initializer=get_initializer(config.initializer_range),
  145. name="value_global",
  146. )
  147. self.dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
  148. self.global_dropout = keras.layers.Dropout(config.attention_probs_dropout_prob)
  149. self.layer_id = layer_id
  150. attention_window = config.attention_window[self.layer_id]
  151. assert (
  152. attention_window % 2 == 0
  153. ), f"`attention_window` for layer {self.layer_id} has to be an even value. Given {attention_window}"
  154. assert (
  155. attention_window > 0
  156. ), f"`attention_window` for layer {self.layer_id} has to be positive. Given {attention_window}"
  157. self.one_sided_attn_window_size = attention_window // 2
  158. def build(self, input_shape=None):
  159. if not self.built:
  160. with tf.name_scope("query_global"):
  161. self.query_global.build((self.config.hidden_size,))
  162. with tf.name_scope("key_global"):
  163. self.key_global.build((self.config.hidden_size,))
  164. with tf.name_scope("value_global"):
  165. self.value_global.build((self.config.hidden_size,))
  166. if self.built:
  167. return
  168. self.built = True
  169. if getattr(self, "query", None) is not None:
  170. with tf.name_scope(self.query.name):
  171. self.query.build([None, None, self.config.hidden_size])
  172. if getattr(self, "key", None) is not None:
  173. with tf.name_scope(self.key.name):
  174. self.key.build([None, None, self.config.hidden_size])
  175. if getattr(self, "value", None) is not None:
  176. with tf.name_scope(self.value.name):
  177. self.value.build([None, None, self.config.hidden_size])
  178. if getattr(self, "query_global", None) is not None:
  179. with tf.name_scope(self.query_global.name):
  180. self.query_global.build([None, None, self.config.hidden_size])
  181. if getattr(self, "key_global", None) is not None:
  182. with tf.name_scope(self.key_global.name):
  183. self.key_global.build([None, None, self.config.hidden_size])
  184. if getattr(self, "value_global", None) is not None:
  185. with tf.name_scope(self.value_global.name):
  186. self.value_global.build([None, None, self.config.hidden_size])
    def call(
        self,
        inputs,
        training=False,
    ):
        """
        LongformerSelfAttention expects *len(hidden_states)* to be multiple of *attention_window*. Padding to
        *attention_window* happens in LongformerModel.forward to avoid redoing the padding on each layer.

        The *attention_mask* is changed in [`LongformerModel.forward`] from 0, 1, 2 to:

            - -10000: no attention
            - 0: local attention
            - +10000: global attention

        `inputs` is a 6-tuple of (hidden_states, attention_mask, layer_head_mask,
        is_index_masked, is_index_global_attn, is_global_attn). Returns a tuple of
        (attn_output, attn_probs, global_attn_probs).
        """
        # retrieve input args
        (
            hidden_states,
            attention_mask,
            layer_head_mask,
            is_index_masked,
            is_index_global_attn,
            is_global_attn,
        ) = inputs
        # project hidden states
        query_vectors = self.query(hidden_states)
        key_vectors = self.key(hidden_states)
        value_vectors = self.value(hidden_states)
        batch_size, seq_len, embed_dim = shape_list(hidden_states)
        tf.debugging.assert_equal(
            embed_dim,
            self.embed_dim,
            message=f"hidden_states should have embed_dim = {self.embed_dim}, but has {embed_dim}",
        )
        # normalize query by sqrt(head_dim) before the sliding-window matmul
        query_vectors /= tf.math.sqrt(tf.cast(self.head_dim, dtype=query_vectors.dtype))
        query_vectors = tf.reshape(query_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
        key_vectors = tf.reshape(key_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
        # attn_probs = (batch_size, seq_len, num_heads, window*2+1)
        attn_scores = self._sliding_chunks_query_key_matmul(
            query_vectors, key_vectors, self.one_sided_attn_window_size
        )
        # values to pad for attention probs
        remove_from_windowed_attention_mask = attention_mask != 0
        # cast to fp32/fp16 then replace 1's with -inf
        float_mask = tf.cast(remove_from_windowed_attention_mask, dtype=query_vectors.dtype) * LARGE_NEGATIVE
        # diagonal mask with zeros everywhere and -inf inplace of padding
        diagonal_mask = self._sliding_chunks_query_key_matmul(
            tf.ones(shape_list(attention_mask)),
            float_mask,
            self.one_sided_attn_window_size,
        )
        # pad local attention probs
        attn_scores += diagonal_mask
        tf.debugging.assert_equal(
            shape_list(attn_scores),
            [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1],
            message=(
                f"attn_probs should be of size ({batch_size}, {seq_len}, {self.num_heads},"
                f" {self.one_sided_attn_window_size * 2 + 1}), but is of size {shape_list(attn_scores)}"
            ),
        )
        # compute global attn indices required through out forward fn
        (
            max_num_global_attn_indices,
            is_index_global_attn_nonzero,
            is_local_index_global_attn_nonzero,
            is_local_index_no_global_attn_nonzero,
        ) = self._get_global_attn_indices(is_index_global_attn)
        # this function is only relevant for global attention
        if is_global_attn:
            attn_scores = self._concat_with_global_key_attn_probs(
                attn_scores=attn_scores,
                query_vectors=query_vectors,
                key_vectors=key_vectors,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
            )
        attn_probs = stable_softmax(attn_scores, axis=-1)
        # softmax sometimes inserts NaN if all positions are masked, replace them with 0
        # Make sure to create a mask with the proper shape:
        # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
        # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
        if is_global_attn:
            masked_index = tf.tile(
                is_index_masked[:, :, None, None],
                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
            )
        else:
            masked_index = tf.tile(
                is_index_masked[:, :, None, None],
                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
            )
        attn_probs = tf.where(
            masked_index,
            tf.zeros(shape_list(masked_index), dtype=attn_probs.dtype),
            attn_probs,
        )
        # optional per-head mask, broadcast over (batch, seq, window) dims
        if layer_head_mask is not None:
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.num_heads],
                message=(
                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                    f" {shape_list(layer_head_mask)}"
                ),
            )
            attn_probs = tf.reshape(layer_head_mask, (1, 1, -1, 1)) * attn_probs
        # apply dropout
        attn_probs = self.dropout(attn_probs, training=training)
        value_vectors = tf.reshape(value_vectors, (batch_size, seq_len, self.num_heads, self.head_dim))
        # if global attention, compute sum of global and local attn
        if is_global_attn:
            attn_output = self._compute_attn_output_with_global_indices(
                value_vectors=value_vectors,
                attn_probs=attn_probs,
                max_num_global_attn_indices=max_num_global_attn_indices,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
            )
        else:
            attn_output = self._sliding_chunks_matmul_attn_probs_value(
                attn_probs, value_vectors, self.one_sided_attn_window_size
            )
        tf.debugging.assert_equal(
            shape_list(attn_output), [batch_size, seq_len, self.num_heads, self.head_dim], message="Unexpected size"
        )
        attn_output = tf.reshape(attn_output, (batch_size, seq_len, embed_dim))
        # compute value for global attention and overwrite to attention output
        if is_global_attn:
            attn_output, global_attn_probs = self._compute_global_attn_output_from_hidden(
                attn_output=attn_output,
                hidden_states=hidden_states,
                max_num_global_attn_indices=max_num_global_attn_indices,
                layer_head_mask=layer_head_mask,
                is_local_index_global_attn_nonzero=is_local_index_global_attn_nonzero,
                is_index_global_attn_nonzero=is_index_global_attn_nonzero,
                is_local_index_no_global_attn_nonzero=is_local_index_no_global_attn_nonzero,
                is_index_masked=is_index_masked,
                training=training,
            )
        else:
            # Leave attn_output unchanged
            global_attn_probs = tf.zeros((batch_size, self.num_heads, max_num_global_attn_indices, seq_len))
        # make sure that local attention probabilities are set to 0 for indices of global attn
        # Make sure to create a mask with the proper shape:
        # if is_global_attn==True => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1]
        # if is_global_attn==False => [batch_size, seq_len, self.num_heads, self.one_sided_attn_window_size * 2 + 1]
        if is_global_attn:
            masked_global_attn_index = tf.tile(
                is_index_global_attn[:, :, None, None],
                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + max_num_global_attn_indices + 1),
            )
        else:
            masked_global_attn_index = tf.tile(
                is_index_global_attn[:, :, None, None],
                (1, 1, self.num_heads, self.one_sided_attn_window_size * 2 + 1),
            )
        attn_probs = tf.where(
            masked_global_attn_index,
            tf.zeros(shape_list(masked_global_attn_index), dtype=attn_probs.dtype),
            attn_probs,
        )
        outputs = (attn_output, attn_probs, global_attn_probs)
        return outputs
    def _sliding_chunks_query_key_matmul(self, query, key, window_overlap):
        """
        Matrix multiplication of query and key tensors using with a sliding window attention pattern. This
        implementation splits the input into overlapping chunks of size 2w (e.g. 512 for pretrained Longformer) with an
        overlap of size window_overlap

        Both `query` and `key` are expected to be (batch_size, seq_len, num_heads, head_dim);
        the result is (batch_size, seq_len, num_heads, 2 * window_overlap + 1).
        """
        batch_size, seq_len, num_heads, head_dim = shape_list(query)
        tf.debugging.assert_equal(
            seq_len % (window_overlap * 2),
            0,
            message=f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}",
        )
        tf.debugging.assert_equal(
            shape_list(query),
            shape_list(key),
            message=(
                f"Shape of query and key should be equal, but got query: {shape_list(query)} and key:"
                f" {shape_list(key)}"
            ),
        )
        chunks_count = seq_len // window_overlap - 1
        # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
        query = tf.reshape(
            tf.transpose(query, (0, 2, 1, 3)),
            (batch_size * num_heads, seq_len, head_dim),
        )
        key = tf.reshape(tf.transpose(key, (0, 2, 1, 3)), (batch_size * num_heads, seq_len, head_dim))
        chunked_query = self._chunk(query, window_overlap)
        chunked_key = self._chunk(key, window_overlap)
        # matrix multiplication
        # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
        # bcyd: batch_size * num_heads x chunks x 2window_overlap x head_dim
        # bcxy: batch_size * num_heads x chunks x 2window_overlap x 2window_overlap
        chunked_query = tf.cast(chunked_query, dtype=chunked_key.dtype)
        chunked_attention_scores = tf.einsum("bcxd,bcyd->bcxy", chunked_query, chunked_key)  # multiply
        # convert diagonals into columns
        paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 1], [0, 0]])
        diagonal_chunked_attention_scores = self._pad_and_transpose_last_two_dims(chunked_attention_scores, paddings)
        # allocate space for the overall attention matrix where the chunks are combined. The last dimension
        # has (window_overlap * 2 + 1) columns. The first (window_overlap) columns are the window_overlap lower triangles (attention from a word to
        # window_overlap previous words). The following column is attention score from each word to itself, then
        # followed by window_overlap columns for the upper triangle.
        # copy parts from diagonal_chunked_attention_scores into the combined matrix of attentions
        # - copying the main diagonal and the upper triangle
        # TODO: This code is most likely not very efficient and should be improved
        diagonal_attn_scores_up_triang = tf.concat(
            [
                diagonal_chunked_attention_scores[:, :, :window_overlap, : window_overlap + 1],
                diagonal_chunked_attention_scores[:, -1:, window_overlap:, : window_overlap + 1],
            ],
            axis=1,
        )
        # - copying the lower triangle
        diagonal_attn_scores_low_triang = tf.concat(
            [
                tf.zeros(
                    (batch_size * num_heads, 1, window_overlap, window_overlap),
                    dtype=diagonal_chunked_attention_scores.dtype,
                ),
                diagonal_chunked_attention_scores[:, :, -(window_overlap + 1) : -1, window_overlap + 1 :],
            ],
            axis=1,
        )
        # the first chunk has no preceding tokens; roll the scores into place and
        # zero-pad so it lines up with the other chunks
        diagonal_attn_scores_first_chunk = tf.concat(
            [
                tf.roll(
                    diagonal_chunked_attention_scores,
                    shift=[1, window_overlap],
                    axis=[2, 3],
                )[:, :, :window_overlap, :window_overlap],
                tf.zeros(
                    (batch_size * num_heads, 1, window_overlap, window_overlap),
                    dtype=diagonal_chunked_attention_scores.dtype,
                ),
            ],
            axis=1,
        )
        # boolean mask selecting only chunk index 0
        first_chunk_mask = (
            tf.tile(
                tf.range(chunks_count + 1, dtype=tf.int64)[None, :, None, None],
                (batch_size * num_heads, 1, window_overlap, window_overlap),
            )
            < 1
        )
        diagonal_attn_scores_low_triang = tf.where(
            first_chunk_mask,
            diagonal_attn_scores_first_chunk,
            diagonal_attn_scores_low_triang,
        )
        # merging upper and lower triangle
        diagonal_attention_scores = tf.concat(
            [diagonal_attn_scores_low_triang, diagonal_attn_scores_up_triang], axis=-1
        )
        # separate batch_size and num_heads dimensions again
        diagonal_attention_scores = tf.transpose(
            tf.reshape(
                diagonal_attention_scores,
                (batch_size, num_heads, seq_len, 2 * window_overlap + 1),
            ),
            (0, 2, 1, 3),
        )
        diagonal_attention_scores = self._mask_invalid_locations(diagonal_attention_scores, window_overlap)
        return diagonal_attention_scores
  455. @staticmethod
  456. def _mask_invalid_locations(input_tensor, window_overlap):
  457. # create correct upper triangle bool mask
  458. mask_2d_upper = tf.reverse(
  459. tf.linalg.band_part(tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
  460. axis=[0],
  461. )
  462. # pad to full matrix
  463. padding = tf.convert_to_tensor(
  464. [[0, shape_list(input_tensor)[1] - window_overlap], [0, shape_list(input_tensor)[3] - window_overlap - 1]]
  465. )
  466. # create lower mask
  467. mask_2d = tf.pad(mask_2d_upper, padding)
  468. # combine with upper mask
  469. mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])
  470. # broadcast to full matrix
  471. mask_4d = tf.tile(mask_2d[None, :, None, :], (shape_list(input_tensor)[0], 1, 1, 1))
  472. # inf tensor used for masking
  473. inf_tensor = -float("inf") * tf.ones_like(input_tensor)
  474. # mask
  475. input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor, input_tensor)
  476. return input_tensor
def _sliding_chunks_matmul_attn_probs_value(self, attn_probs, value, window_overlap):
    """
    Same as _sliding_chunks_query_key_matmul but for attn_probs and value tensors. Returned tensor will be of the
    same shape as `attn_probs`

    Multiplies the sliding-window attention probabilities with the value vectors
    using the same overlapping-chunk trick, avoiding materializing the full
    (seq_len x seq_len) attention matrix.
    """
    # value is expected as (batch_size, seq_len, num_heads, head_dim)
    batch_size, seq_len, num_heads, head_dim = shape_list(value)
    tf.debugging.assert_equal(
        seq_len % (window_overlap * 2), 0, message="Seq_len has to be multiple of 2 * window_overlap"
    )
    tf.debugging.assert_equal(
        shape_list(attn_probs)[:3],
        shape_list(value)[:3],
        message="value and attn_probs must have same dims (except head_dim)",
    )
    tf.debugging.assert_equal(
        shape_list(attn_probs)[3],
        2 * window_overlap + 1,
        message="attn_probs last dim has to be 2 * window_overlap + 1",
    )
    chunks_count = seq_len // window_overlap - 1
    # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
    chunked_attn_probs = tf.reshape(
        tf.transpose(attn_probs, (0, 2, 1, 3)),
        (
            batch_size * num_heads,
            seq_len // window_overlap,
            window_overlap,
            2 * window_overlap + 1,
        ),
    )
    # group batch_size and num_heads dimensions into one
    value = tf.reshape(
        tf.transpose(value, (0, 2, 1, 3)),
        (batch_size * num_heads, seq_len, head_dim),
    )
    # pad seq_len with w at the beginning of the sequence and another window overlap at the end
    paddings = tf.convert_to_tensor([[0, 0], [window_overlap, window_overlap], [0, 0]])
    # pad value is arbitrary (-1): padded positions never contribute to valid outputs
    padded_value = tf.pad(value, paddings, constant_values=-1)
    # chunk padded_value into chunks of size 3 window overlap and an overlap of size window overlap
    frame_size = 3 * window_overlap * head_dim
    # hop chosen so that tf.signal.frame yields exactly chunks_count + 1 frames
    frame_hop_size = (shape_list(padded_value)[1] * head_dim - frame_size) // chunks_count
    chunked_value = tf.signal.frame(
        tf.reshape(padded_value, (batch_size * num_heads, -1)),
        frame_size,
        frame_hop_size,
    )
    chunked_value = tf.reshape(
        chunked_value,
        (batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim),
    )
    tf.debugging.assert_equal(
        shape_list(chunked_value),
        [batch_size * num_heads, chunks_count + 1, 3 * window_overlap, head_dim],
        message="Chunked value has the wrong shape",
    )
    # shift probability rows so that columns line up with the value chunk diagonals
    chunked_attn_probs = self._pad_and_diagonalize(chunked_attn_probs)
    # per-chunk matmul: (b*h, chunk, w, 3w) x (b*h, chunk, 3w, d) -> (b*h, chunk, w, d)
    context = tf.einsum("bcwd,bcdh->bcwh", chunked_attn_probs, chunked_value)
    # split batch and heads back apart; result is (batch_size, seq_len, num_heads, head_dim)
    context = tf.transpose(
        tf.reshape(context, (batch_size, num_heads, seq_len, head_dim)),
        (0, 2, 1, 3),
    )
    return context
  539. @staticmethod
  540. def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings):
  541. """pads rows and then flips rows and columns"""
  542. hidden_states_padded = tf.pad(
  543. hidden_states_padded, paddings
  544. ) # padding value is not important because it will be overwritten
  545. batch_size, chunk_size, seq_length, hidden_dim = shape_list(hidden_states_padded)
  546. hidden_states_padded = tf.reshape(hidden_states_padded, (batch_size, chunk_size, hidden_dim, seq_length))
  547. return hidden_states_padded
@staticmethod
def _pad_and_diagonalize(chunked_hidden_states):
    """
    shift every row 1 step right, converting columns into diagonals.

    Example:

    ```python
    chunked_hidden_states: [
        0.4983,
        2.6918,
        -0.0071,
        1.0492,
        -1.8348,
        0.7672,
        0.2986,
        0.0285,
        -0.7584,
        0.4206,
        -0.0405,
        0.1599,
        2.0514,
        -1.1600,
        0.5372,
        0.2629,
    ]
    window_overlap = num_rows = 4
    ```

    (pad & diagonalize) => [ 0.4983, 2.6918, -0.0071, 1.0492, 0.0000, 0.0000, 0.0000
    0.0000, -1.8348, 0.7672, 0.2986, 0.0285, 0.0000, 0.0000 0.0000, 0.0000, -0.7584, 0.4206,
    -0.0405, 0.1599, 0.0000 0.0000, 0.0000, 0.0000, 2.0514, -1.1600, 0.5372, 0.2629 ]
    """
    total_num_heads, num_chunks, window_overlap, hidden_dim = shape_list(chunked_hidden_states)
    # pad last dim by window_overlap + 1 columns; the pad value is irrelevant
    # because those positions are sliced away or overwritten below
    paddings = tf.convert_to_tensor([[0, 0], [0, 0], [0, 0], [0, window_overlap + 1]])
    chunked_hidden_states = tf.pad(
        chunked_hidden_states, paddings
    )  # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap + 1)
    # flatten the last two dims so that the per-row padding turns into a shift
    chunked_hidden_states = tf.reshape(
        chunked_hidden_states, (total_num_heads, num_chunks, -1)
    )  # total_num_heads x num_chunks x (window_overlap * (hidden_dim + window_overlap + 1))
    # drop the trailing window_overlap elements so the row length divides evenly
    chunked_hidden_states = chunked_hidden_states[
        :, :, :-window_overlap
    ]  # total_num_heads x num_chunks x (window_overlap * (hidden_dim + window_overlap))
    # re-chunk into rows: each row is now shifted one step right vs. the previous
    chunked_hidden_states = tf.reshape(
        chunked_hidden_states,
        (total_num_heads, num_chunks, window_overlap, window_overlap + hidden_dim),
    )  # total_num_heads x num_chunks x window_overlap x (hidden_dim + window_overlap)
    # drop the final padding column
    chunked_hidden_states = chunked_hidden_states[:, :, :, :-1]
    return chunked_hidden_states
  595. @staticmethod
  596. def _chunk(hidden_states, window_overlap):
  597. """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
  598. batch_size, seq_length, hidden_dim = shape_list(hidden_states)
  599. num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1
  600. # define frame size and frame stride (similar to convolution)
  601. frame_hop_size = window_overlap * hidden_dim
  602. frame_size = 2 * frame_hop_size
  603. hidden_states = tf.reshape(hidden_states, (batch_size, seq_length * hidden_dim))
  604. # chunk with overlap
  605. chunked_hidden_states = tf.signal.frame(hidden_states, frame_size, frame_hop_size)
  606. tf.debugging.assert_equal(
  607. shape_list(chunked_hidden_states),
  608. [batch_size, num_output_chunks, frame_size],
  609. message=(
  610. "Make sure chunking is correctly applied. `Chunked hidden states should have output dimension"
  611. f" {[batch_size, frame_size, num_output_chunks]}, but got {shape_list(chunked_hidden_states)}."
  612. ),
  613. )
  614. chunked_hidden_states = tf.reshape(
  615. chunked_hidden_states,
  616. (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
  617. )
  618. return chunked_hidden_states
  619. @staticmethod
  620. def _get_global_attn_indices(is_index_global_attn):
  621. """compute global attn indices required throughout forward pass"""
  622. # helper variable
  623. num_global_attn_indices = tf.math.count_nonzero(is_index_global_attn, axis=1)
  624. num_global_attn_indices = tf.cast(num_global_attn_indices, dtype=tf.constant(1).dtype)
  625. # max number of global attn indices in batch
  626. max_num_global_attn_indices = tf.reduce_max(num_global_attn_indices)
  627. # indices of global attn
  628. is_index_global_attn_nonzero = tf.where(is_index_global_attn)
  629. # helper variable
  630. is_local_index_global_attn = tf.range(max_num_global_attn_indices) < tf.expand_dims(
  631. num_global_attn_indices, axis=-1
  632. )
  633. # location of the non-padding values within global attention indices
  634. is_local_index_global_attn_nonzero = tf.where(is_local_index_global_attn)
  635. # location of the padding values within global attention indices
  636. is_local_index_no_global_attn_nonzero = tf.where(tf.math.logical_not(is_local_index_global_attn))
  637. return (
  638. max_num_global_attn_indices,
  639. is_index_global_attn_nonzero,
  640. is_local_index_global_attn_nonzero,
  641. is_local_index_no_global_attn_nonzero,
  642. )
def _concat_with_global_key_attn_probs(
    self,
    attn_scores,
    key_vectors,
    query_vectors,
    max_num_global_attn_indices,
    is_index_global_attn_nonzero,
    is_local_index_global_attn_nonzero,
    is_local_index_no_global_attn_nonzero,
):
    """
    Compute attention scores of every query against the global keys and prepend
    them to the sliding-window scores along the last axis. Padding slots in the
    global block are masked with -10000 so softmax assigns them ~0 probability.
    """
    batch_size = shape_list(key_vectors)[0]
    # select global key vectors
    global_key_vectors = tf.gather_nd(key_vectors, is_index_global_attn_nonzero)
    # create only global key vectors
    key_vectors_only_global = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        global_key_vectors,
        shape=(
            batch_size,
            max_num_global_attn_indices,
            self.num_heads,
            self.head_dim,
        ),
    )
    # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs", query_vectors, key_vectors_only_global)
    # (batch_size, max_num_global_attn_indices, seq_len, num_heads)
    # transpose so the global-slot axis is first after batch, as required by scatter below
    attn_probs_from_global_key_trans = tf.transpose(attn_probs_from_global_key, (0, 3, 1, 2))
    mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
        shape_list(attn_probs_from_global_key_trans)[-2:]
    )
    mask = tf.ones(mask_shape) * -10000.0
    mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)
    # scatter mask into the padding slots of the global block
    attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
        attn_probs_from_global_key_trans,
        is_local_index_no_global_attn_nonzero,
        mask,
    )
    # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
    attn_probs_from_global_key = tf.transpose(attn_probs_from_global_key_trans, (0, 2, 3, 1))
    # concat to attn_probs
    # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
    attn_scores = tf.concat((attn_probs_from_global_key, attn_scores), axis=-1)
    return attn_scores
  688. def _compute_attn_output_with_global_indices(
  689. self,
  690. value_vectors,
  691. attn_probs,
  692. max_num_global_attn_indices,
  693. is_index_global_attn_nonzero,
  694. is_local_index_global_attn_nonzero,
  695. ):
  696. batch_size = shape_list(attn_probs)[0]
  697. # cut local attn probs to global only
  698. attn_probs_only_global = attn_probs[:, :, :, :max_num_global_attn_indices]
  699. # select global value vectors
  700. global_value_vectors = tf.gather_nd(value_vectors, is_index_global_attn_nonzero)
  701. # create only global value vectors
  702. value_vectors_only_global = tf.scatter_nd(
  703. is_local_index_global_attn_nonzero,
  704. global_value_vectors,
  705. shape=(
  706. batch_size,
  707. max_num_global_attn_indices,
  708. self.num_heads,
  709. self.head_dim,
  710. ),
  711. )
  712. # compute attn output only global
  713. attn_output_only_global = tf.einsum("blhs,bshd->blhd", attn_probs_only_global, value_vectors_only_global)
  714. # reshape attn probs
  715. attn_probs_without_global = attn_probs[:, :, :, max_num_global_attn_indices:]
  716. # compute attn output with global
  717. attn_output_without_global = self._sliding_chunks_matmul_attn_probs_value(
  718. attn_probs_without_global, value_vectors, self.one_sided_attn_window_size
  719. )
  720. return attn_output_only_global + attn_output_without_global
def _compute_global_attn_output_from_hidden(
    self,
    attn_output,
    hidden_states,
    max_num_global_attn_indices,
    layer_head_mask,
    is_local_index_global_attn_nonzero,
    is_index_global_attn_nonzero,
    is_local_index_no_global_attn_nonzero,
    is_index_masked,
    training,
):
    """
    Compute full (global-to-all) attention for the global tokens using the
    dedicated query/key/value_global projections, then overwrite those tokens'
    rows in `attn_output` with the result. Returns the updated `attn_output`
    and the global attention probabilities reshaped to
    (batch_size, num_heads, max_num_global_attn_indices, seq_len).
    """
    batch_size, seq_len = shape_list(hidden_states)[:2]
    # prepare global hidden states: gather global tokens, scatter into a dense padded block
    global_attn_hidden_states = tf.gather_nd(hidden_states, is_index_global_attn_nonzero)
    global_attn_hidden_states = tf.scatter_nd(
        is_local_index_global_attn_nonzero,
        global_attn_hidden_states,
        shape=(batch_size, max_num_global_attn_indices, self.embed_dim),
    )
    # global key, query, value — queries only for the global tokens, keys/values for all tokens
    global_query_vectors_only_global = self.query_global(global_attn_hidden_states)
    global_key_vectors = self.key_global(hidden_states)
    global_value_vectors = self.value_global(hidden_states)
    # normalize queries by sqrt(head_dim)
    global_query_vectors_only_global /= tf.math.sqrt(
        tf.cast(self.head_dim, dtype=global_query_vectors_only_global.dtype)
    )
    global_query_vectors_only_global = self.reshape_and_transpose(global_query_vectors_only_global, batch_size)
    global_key_vectors = self.reshape_and_transpose(global_key_vectors, batch_size)
    global_value_vectors = self.reshape_and_transpose(global_value_vectors, batch_size)
    # compute attn scores: (batch*heads, max_global, seq_len)
    global_attn_scores = tf.matmul(global_query_vectors_only_global, global_key_vectors, transpose_b=True)
    tf.debugging.assert_equal(
        shape_list(global_attn_scores),
        [batch_size * self.num_heads, max_num_global_attn_indices, seq_len],
        message=(
            "global_attn_scores have the wrong size. Size should be"
            f" {(batch_size * self.num_heads, max_num_global_attn_indices, seq_len)}, but is"
            f" {shape_list(global_attn_scores)}."
        ),
    )
    global_attn_scores = tf.reshape(
        global_attn_scores,
        (batch_size, self.num_heads, max_num_global_attn_indices, seq_len),
    )
    # transpose so the global-slot axis follows batch, as required by the scatter below
    global_attn_scores_trans = tf.transpose(global_attn_scores, (0, 2, 1, 3))
    mask_shape = (shape_list(is_local_index_no_global_attn_nonzero)[0],) + tuple(
        shape_list(global_attn_scores_trans)[-2:]
    )
    global_attn_mask = tf.ones(mask_shape) * -10000.0
    global_attn_mask = tf.cast(global_attn_mask, dtype=global_attn_scores_trans.dtype)
    # scatter mask into the padding slots of the global block
    global_attn_scores_trans = tf.tensor_scatter_nd_update(
        global_attn_scores_trans,
        is_local_index_no_global_attn_nonzero,
        global_attn_mask,
    )
    global_attn_scores = tf.transpose(global_attn_scores_trans, (0, 2, 1, 3))
    # mask global attn scores at padded (masked) key positions
    attn_mask = tf.tile(is_index_masked[:, None, None, :], (1, shape_list(global_attn_scores)[1], 1, 1))
    global_attn_scores = tf.where(attn_mask, -10000.0, global_attn_scores)
    global_attn_scores = tf.reshape(
        global_attn_scores,
        (batch_size * self.num_heads, max_num_global_attn_indices, seq_len),
    )
    # compute global attn probs
    global_attn_probs_float = stable_softmax(global_attn_scores, axis=-1)
    # apply layer head masking
    if layer_head_mask is not None:
        tf.debugging.assert_equal(
            shape_list(layer_head_mask),
            [self.num_heads],
            message=(
                f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                f" {shape_list(layer_head_mask)}"
            ),
        )
        global_attn_probs_float = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
            global_attn_probs_float, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
        )
        global_attn_probs_float = tf.reshape(
            global_attn_probs_float, (batch_size * self.num_heads, max_num_global_attn_indices, seq_len)
        )
    # dropout
    global_attn_probs = self.global_dropout(global_attn_probs_float, training=training)
    # global attn output: (batch*heads, max_global, head_dim)
    global_attn_output = tf.matmul(global_attn_probs, global_value_vectors)
    tf.debugging.assert_equal(
        shape_list(global_attn_output),
        [batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim],
        message=(
            "global_attn_output tensor has the wrong size. Size should be"
            f" {(batch_size * self.num_heads, max_num_global_attn_indices, self.head_dim)}, but is"
            f" {shape_list(global_attn_output)}."
        ),
    )
    global_attn_output = tf.reshape(
        global_attn_output,
        (batch_size, self.num_heads, max_num_global_attn_indices, self.head_dim),
    )
    # get only non zero global attn output (drop the padding slots)
    nonzero_global_attn_output = tf.gather_nd(
        tf.transpose(global_attn_output, (0, 2, 1, 3)),
        is_local_index_global_attn_nonzero,
    )
    nonzero_global_attn_output = tf.reshape(
        nonzero_global_attn_output,
        (shape_list(is_local_index_global_attn_nonzero)[0], -1),
    )
    # overwrite values with global attention at the global tokens' positions
    attn_output = tf.tensor_scatter_nd_update(
        attn_output, is_index_global_attn_nonzero, nonzero_global_attn_output
    )
    global_attn_probs = tf.reshape(
        global_attn_probs, (batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
    )
    return attn_output, global_attn_probs
  839. def reshape_and_transpose(self, vector, batch_size):
  840. return tf.reshape(
  841. tf.transpose(
  842. tf.reshape(vector, (batch_size, -1, self.num_heads, self.head_dim)),
  843. (0, 2, 1, 3),
  844. ),
  845. (batch_size * self.num_heads, -1, self.head_dim),
  846. )
  847. class TFLEDEncoderAttention(keras.layers.Layer):
  848. def __init__(self, config, layer_id, **kwargs):
  849. super().__init__(**kwargs)
  850. self.longformer_self_attn = TFLEDEncoderSelfAttention(config, layer_id=layer_id, name="longformer_self_attn")
  851. self.output_dense = keras.layers.Dense(config.d_model, use_bias=True, name="output")
  852. self.config = config
  853. def call(self, inputs, training=False):
  854. (
  855. hidden_states,
  856. attention_mask,
  857. layer_head_mask,
  858. is_index_masked,
  859. is_index_global_attn,
  860. is_global_attn,
  861. ) = inputs
  862. self_outputs = self.longformer_self_attn(
  863. [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
  864. training=training,
  865. )
  866. attention_output = self.output_dense(self_outputs[0], training=training)
  867. outputs = (attention_output,) + self_outputs[1:]
  868. return outputs
  869. def build(self, input_shape=None):
  870. if self.built:
  871. return
  872. self.built = True
  873. if getattr(self, "longformer_self_attn", None) is not None:
  874. with tf.name_scope(self.longformer_self_attn.name):
  875. self.longformer_self_attn.build(None)
  876. if getattr(self, "output_dense", None) is not None:
  877. with tf.name_scope(self.output_dense.name):
  878. self.output_dense.build([None, None, self.config.d_model])
class TFLEDDecoderAttention(keras.layers.Layer):
    """Multi-headed attention from "Attention Is All You Need"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = keras.layers.Dropout(dropout)
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        # standard 1/sqrt(head_dim) attention scaling, applied to queries
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
        self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
        self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
        self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")

    def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))

    def call(
        self,
        hidden_states: tf.Tensor,
        key_value_states: tf.Tensor | None = None,
        past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
        attention_mask: tf.Tensor | None = None,
        layer_head_mask: tf.Tensor | None = None,
        training=False,
    ) -> Tuple[tf.Tensor, tf.Tensor | None]:
        """Input shape: Batch x Time x Channel"""
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        bsz, tgt_len, embed_dim = shape_list(hidden_states)
        # get query proj (pre-scaled by 1/sqrt(head_dim))
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            # append the new key/value states after the cached ones (axis 2 = time)
            key_states = tf.concat([past_key_value[0], key_states], axis=2)
            value_states = tf.concat([past_key_value[1], value_states], axis=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # fold the head axis into the batch axis for a single batched matmul
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
        key_states = tf.reshape(key_states, proj_shape)
        value_states = tf.reshape(value_states, proj_shape)

        src_len = shape_list(key_states)[1]
        attn_weights = tf.matmul(query_states, key_states, transpose_b=True)

        tf.debugging.assert_equal(
            shape_list(attn_weights),
            [bsz * self.num_heads, tgt_len, src_len],
            message=(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {shape_list(attn_weights)}"
            ),
        )

        if attention_mask is not None:
            tf.debugging.assert_equal(
                shape_list(attention_mask),
                [bsz, 1, tgt_len, src_len],
                message=(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {shape_list(attention_mask)}"
                ),
            )
            # additive mask: padding positions carry large negative values
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + tf.cast(
                attention_mask, dtype=attn_weights.dtype
            )
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_weights = stable_softmax(attn_weights, axis=-1)

        if layer_head_mask is not None:
            tf.debugging.assert_equal(
                shape_list(layer_head_mask),
                [self.num_heads],
                message=(
                    f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
                    f" {shape_list(layer_head_mask)}"
                ),
            )
            # zero out (or scale) whole heads according to the mask
            attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
                attn_weights, (bsz, self.num_heads, tgt_len, src_len)
            )
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        attn_probs = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_probs, value_states)

        tf.debugging.assert_equal(
            shape_list(attn_output),
            [bsz * self.num_heads, tgt_len, self.head_dim],
            message=(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {shape_list(attn_output)}"
            ),
        )

        # merge heads back: (bsz*heads, tgt, head_dim) -> (bsz, tgt, embed_dim)
        attn_output = tf.transpose(
            tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
        )
        attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
        attn_output = self.out_proj(attn_output)
        # return pre-dropout weights, split per head, for output_attentions
        attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))

        return attn_output, attn_weights, past_key_value

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "k_proj", None) is not None:
            with tf.name_scope(self.k_proj.name):
                self.k_proj.build([None, None, self.embed_dim])
        if getattr(self, "q_proj", None) is not None:
            with tf.name_scope(self.q_proj.name):
                self.q_proj.build([None, None, self.embed_dim])
        if getattr(self, "v_proj", None) is not None:
            with tf.name_scope(self.v_proj.name):
                self.v_proj.build([None, None, self.embed_dim])
        if getattr(self, "out_proj", None) is not None:
            with tf.name_scope(self.out_proj.name):
                self.out_proj.build([None, None, self.embed_dim])
  1022. class TFLEDEncoderLayer(keras.layers.Layer):
  1023. def __init__(self, config: LEDConfig, layer_id: int, **kwargs):
  1024. super().__init__(**kwargs)
  1025. self.embed_dim = config.d_model
  1026. self.self_attn = TFLEDEncoderAttention(config, layer_id, name="self_attn")
  1027. self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
  1028. self.dropout = keras.layers.Dropout(config.dropout)
  1029. self.activation_fn = get_tf_activation(config.activation_function)
  1030. self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
  1031. self.fc1 = keras.layers.Dense(config.encoder_ffn_dim, name="fc1")
  1032. self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
  1033. self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
  1034. self.config = config
  1035. def call(
  1036. self,
  1037. hidden_states: tf.Tensor,
  1038. attention_mask: tf.Tensor,
  1039. layer_head_mask: tf.Tensor,
  1040. is_index_masked: tf.Tensor,
  1041. is_index_global_attn: tf.Tensor,
  1042. is_global_attn: bool,
  1043. training=False,
  1044. ):
  1045. """
  1046. Args:
  1047. hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
  1048. attention_mask (`tf.Tensor`): attention mask of size
  1049. *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
  1050. layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
  1051. *(config.encoder_attention_heads,)*.
  1052. """
  1053. residual = hidden_states
  1054. layer_outputs = self.self_attn(
  1055. [hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn],
  1056. training=training,
  1057. )
  1058. hidden_states = layer_outputs[0]
  1059. tf.debugging.assert_equal(
  1060. shape_list(hidden_states),
  1061. shape_list(residual),
  1062. message=f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(hidden_states)}",
  1063. )
  1064. hidden_states = self.dropout(hidden_states, training=training)
  1065. hidden_states = residual + hidden_states
  1066. hidden_states = self.self_attn_layer_norm(hidden_states)
  1067. residual = hidden_states
  1068. hidden_states = self.activation_fn(self.fc1(hidden_states))
  1069. hidden_states = self.activation_dropout(hidden_states, training=training)
  1070. hidden_states = self.fc2(hidden_states)
  1071. hidden_states = self.dropout(hidden_states, training=training)
  1072. hidden_states = residual + hidden_states
  1073. hidden_states = self.final_layer_norm(hidden_states)
  1074. return (hidden_states,) + layer_outputs[1:]
  1075. def build(self, input_shape=None):
  1076. if self.built:
  1077. return
  1078. self.built = True
  1079. if getattr(self, "self_attn", None) is not None:
  1080. with tf.name_scope(self.self_attn.name):
  1081. self.self_attn.build(None)
  1082. if getattr(self, "self_attn_layer_norm", None) is not None:
  1083. with tf.name_scope(self.self_attn_layer_norm.name):
  1084. self.self_attn_layer_norm.build([None, None, self.embed_dim])
  1085. if getattr(self, "fc1", None) is not None:
  1086. with tf.name_scope(self.fc1.name):
  1087. self.fc1.build([None, None, self.embed_dim])
  1088. if getattr(self, "fc2", None) is not None:
  1089. with tf.name_scope(self.fc2.name):
  1090. self.fc2.build([None, None, self.config.encoder_ffn_dim])
  1091. if getattr(self, "final_layer_norm", None) is not None:
  1092. with tf.name_scope(self.final_layer_norm.name):
  1093. self.final_layer_norm.build([None, None, self.embed_dim])
  1094. class TFLEDDecoderLayer(keras.layers.Layer):
  1095. def __init__(self, config: LEDConfig, **kwargs):
  1096. super().__init__(**kwargs)
  1097. self.embed_dim = config.d_model
  1098. self.self_attn = TFLEDDecoderAttention(
  1099. embed_dim=self.embed_dim,
  1100. num_heads=config.decoder_attention_heads,
  1101. dropout=config.attention_dropout,
  1102. name="self_attn",
  1103. is_decoder=True,
  1104. )
  1105. self.dropout = keras.layers.Dropout(config.dropout)
  1106. self.activation_fn = get_tf_activation(config.activation_function)
  1107. self.activation_dropout = keras.layers.Dropout(config.activation_dropout)
  1108. self.self_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
  1109. self.encoder_attn = TFLEDDecoderAttention(
  1110. self.embed_dim,
  1111. config.decoder_attention_heads,
  1112. dropout=config.attention_dropout,
  1113. name="encoder_attn",
  1114. is_decoder=True,
  1115. )
  1116. self.encoder_attn_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
  1117. self.fc1 = keras.layers.Dense(config.decoder_ffn_dim, name="fc1")
  1118. self.fc2 = keras.layers.Dense(self.embed_dim, name="fc2")
  1119. self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5, name="final_layer_norm")
  1120. self.config = config
  1121. def call(
  1122. self,
  1123. hidden_states,
  1124. attention_mask: tf.Tensor | None = None,
  1125. encoder_hidden_states: tf.Tensor | None = None,
  1126. encoder_attention_mask: tf.Tensor | None = None,
  1127. layer_head_mask: tf.Tensor | None = None,
  1128. encoder_layer_head_mask: tf.Tensor | None = None,
  1129. past_key_value: Tuple[tf.Tensor] | None = None,
  1130. training=False,
  1131. ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
  1132. """
  1133. Args:
  1134. hidden_states (`tf.Tensor`): input to the layer of shape *(batch, seq_len, embed_dim)*
  1135. attention_mask (`tf.Tensor`): attention mask of size
  1136. *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
  1137. encoder_hidden_states (`tf.Tensor`):
  1138. cross attention input to the layer of shape *(batch, seq_len, embed_dim)*
  1139. encoder_attention_mask (`tf.Tensor`): encoder attention mask of size
  1140. *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values.
  1141. layer_head_mask (`tf.Tensor`): mask for attention heads in a given layer of size
  1142. *(config.encoder_attention_heads,)*.
  1143. encoder_layer_head_mask (`tf.Tensor`): mask for encoder attention heads in a given layer of
  1144. size *(config.encoder_attention_heads,)*.
  1145. past_key_value (`Tuple(tf.Tensor)`): cached past key and value projection states
  1146. """
  1147. residual = hidden_states
  1148. # Self-Attention
  1149. # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  1150. self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
  1151. # add present self-attn cache to positions 1,2 of present_key_value tuple
  1152. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  1153. hidden_states=hidden_states,
  1154. past_key_value=self_attn_past_key_value,
  1155. attention_mask=attention_mask,
  1156. layer_head_mask=layer_head_mask,
  1157. )
  1158. hidden_states = self.dropout(hidden_states, training=training)
  1159. hidden_states = residual + hidden_states
  1160. hidden_states = self.self_attn_layer_norm(hidden_states)
  1161. # Cross-Attention Block
  1162. cross_attn_present_key_value = None
  1163. cross_attn_weights = None
  1164. if encoder_hidden_states is not None:
  1165. residual = hidden_states
  1166. # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
  1167. cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
  1168. hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
  1169. hidden_states=hidden_states,
  1170. key_value_states=encoder_hidden_states,
  1171. attention_mask=encoder_attention_mask,
  1172. layer_head_mask=encoder_layer_head_mask,
  1173. past_key_value=cross_attn_past_key_value,
  1174. )
  1175. hidden_states = self.dropout(hidden_states, training=training)
  1176. hidden_states = residual + hidden_states
  1177. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  1178. # add cross-attn to positions 3,4 of present_key_value tuple
  1179. present_key_value = present_key_value + cross_attn_present_key_value
  1180. # Fully Connected
  1181. residual = hidden_states
  1182. hidden_states = self.activation_fn(self.fc1(hidden_states))
  1183. hidden_states = self.activation_dropout(hidden_states, training=training)
  1184. hidden_states = self.fc2(hidden_states)
  1185. hidden_states = self.dropout(hidden_states, training=training)
  1186. hidden_states = residual + hidden_states
  1187. hidden_states = self.final_layer_norm(hidden_states)
  1188. return (
  1189. hidden_states,
  1190. self_attn_weights,
  1191. cross_attn_weights,
  1192. present_key_value,
  1193. )
  1194. def build(self, input_shape=None):
  1195. if self.built:
  1196. return
  1197. self.built = True
  1198. if getattr(self, "self_attn", None) is not None:
  1199. with tf.name_scope(self.self_attn.name):
  1200. self.self_attn.build(None)
  1201. if getattr(self, "self_attn_layer_norm", None) is not None:
  1202. with tf.name_scope(self.self_attn_layer_norm.name):
  1203. self.self_attn_layer_norm.build([None, None, self.embed_dim])
  1204. if getattr(self, "encoder_attn", None) is not None:
  1205. with tf.name_scope(self.encoder_attn.name):
  1206. self.encoder_attn.build(None)
  1207. if getattr(self, "encoder_attn_layer_norm", None) is not None:
  1208. with tf.name_scope(self.encoder_attn_layer_norm.name):
  1209. self.encoder_attn_layer_norm.build([None, None, self.embed_dim])
  1210. if getattr(self, "fc1", None) is not None:
  1211. with tf.name_scope(self.fc1.name):
  1212. self.fc1.build([None, None, self.embed_dim])
  1213. if getattr(self, "fc2", None) is not None:
  1214. with tf.name_scope(self.fc2.name):
  1215. self.fc2.build([None, None, self.config.decoder_ffn_dim])
  1216. if getattr(self, "final_layer_norm", None) is not None:
  1217. with tf.name_scope(self.final_layer_norm.name):
  1218. self.final_layer_norm.build([None, None, self.embed_dim])
  1219. class TFLEDPreTrainedModel(TFPreTrainedModel):
  1220. config_class = LEDConfig
  1221. base_model_prefix = "led"
  1222. @property
  1223. def input_signature(self):
  1224. sig = super().input_signature
  1225. sig["global_attention_mask"] = tf.TensorSpec((None, None), tf.int32, name="global_attention_mask")
  1226. return sig
  1227. @dataclass
  1228. # Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerBaseModelOutput with TFLongformer->TFLEDEncoder
  1229. class TFLEDEncoderBaseModelOutput(ModelOutput):
  1230. """
  1231. Base class for Longformer's outputs, with potential hidden states, local and global attentions.
  1232. Args:
  1233. last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
  1234. Sequence of hidden-states at the output of the last layer of the model.
  1235. hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  1236. Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
  1237. `(batch_size, sequence_length, hidden_size)`.
  1238. Hidden-states of the model at the output of each layer plus the initial embedding outputs.
  1239. attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  1240. Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x +
  1241. attention_window + 1)`, where `x` is the number of tokens with global attention mask.
  1242. Local attentions weights after the attention softmax, used to compute the weighted average in the
  1243. self-attention heads. Those are the attention weights from every token in the sequence to every token with
  1244. global attention (first `x` values) and to every token in the attention window (remaining `attention_window
  1245. + 1` values). Note that the first `x` values refer to tokens with fixed positions in the text, but the
  1246. remaining `attention_window + 1` values refer to tokens with relative positions: the attention weight of a
  1247. token to itself is located at index `x + attention_window / 2` and the `attention_window / 2` preceding
  1248. (succeeding) values are the attention weights to the `attention_window / 2` preceding (succeeding) tokens.
  1249. If the attention window contains a token with global attention, the attention weight at the corresponding
  1250. index is set to 0; the value should be accessed from the first `x` attention weights. If a token has global
  1251. attention, the attention weights to all other tokens in `attentions` is set to 0, the values should be
  1252. accessed from `global_attentions`.
  1253. global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  1254. Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
  1255. is the number of tokens with global attention mask.
  1256. Global attentions weights after the attention softmax, used to compute the weighted average in the
  1257. self-attention heads. Those are the attention weights from every token with global attention to every token
  1258. in the sequence.
  1259. """
  1260. last_hidden_state: tf.Tensor = None
  1261. hidden_states: Tuple[tf.Tensor, ...] | None = None
  1262. attentions: Tuple[tf.Tensor, ...] | None = None
  1263. global_attentions: Tuple[tf.Tensor, ...] | None = None
  1264. @dataclass
  1265. class TFLEDSeq2SeqModelOutput(ModelOutput):
  1266. """
  1267. Base class for model encoder's outputs that also contains : pre-computed hidden states that can speed up sequential
  1268. decoding.
  1269. Args:
  1270. last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
  1271. Sequence of hidden-states at the output of the last layer of the decoder of the model.
  1272. If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
  1273. hidden_size)` is output.
  1274. past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  1275. List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
  1276. sequence_length, embed_size_per_head)`).
  1277. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  1278. used (see `past_key_values` input) to speed up sequential decoding.
  1279. decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  1280. Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
  1281. `(batch_size, sequence_length, hidden_size)`.
  1282. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
  1283. decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  1284. Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  1285. sequence_length)`.
  1286. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
  1287. self-attention heads.
  1288. cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  1289. Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  1290. sequence_length)`.
  1291. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
  1292. weighted average in the cross-attention heads.
  1293. encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1294. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  1295. encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  1296. Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
  1297. `(batch_size, sequence_length, hidden_size)`.
  1298. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
  1299. encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  1300. Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  1301. sequence_length)`.
  1302. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
  1303. self-attention heads.
  1304. encoder_global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  1305. Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
  1306. is the number of tokens with global attention mask.
  1307. Global attentions weights after the attention softmax, used to compute the weighted average in the
  1308. self-attention heads. Those are the attention weights from every token with global attention to every token
  1309. in the sequence.
  1310. """
  1311. last_hidden_state: tf.Tensor = None
  1312. past_key_values: List[tf.Tensor] | None = None
  1313. decoder_hidden_states: Tuple[tf.Tensor, ...] | None = None
  1314. decoder_attentions: Tuple[tf.Tensor, ...] | None = None
  1315. cross_attentions: Tuple[tf.Tensor, ...] | None = None
  1316. encoder_last_hidden_state: tf.Tensor | None = None
  1317. encoder_hidden_states: Tuple[tf.Tensor, ...] | None = None
  1318. encoder_attentions: Tuple[tf.Tensor, ...] | None = None
  1319. encoder_global_attentions: Tuple[tf.Tensor, ...] | None = None
  1320. @dataclass
  1321. class TFLEDSeq2SeqLMOutput(ModelOutput):
  1322. """
  1323. Base class for sequence-to-sequence language models outputs.
  1324. Args:
  1325. loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  1326. Language modeling loss.
  1327. logits (`tf.Tensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  1328. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  1329. past_key_values (`List[tf.Tensor]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  1330. List of `tf.Tensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads,
  1331. sequence_length, embed_size_per_head)`).
  1332. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
  1333. used (see `past_key_values` input) to speed up sequential decoding.
  1334. decoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  1335. Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
  1336. `(batch_size, sequence_length, hidden_size)`.
  1337. Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
  1338. decoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  1339. Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  1340. sequence_length)`.
  1341. Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
  1342. self-attention heads.
  1343. cross_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  1344. Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  1345. sequence_length)`.
  1346. Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
  1347. weighted average in the cross-attention heads.
  1348. encoder_last_hidden_state (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1349. Sequence of hidden-states at the output of the last layer of the encoder of the model.
  1350. encoder_hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  1351. Tuple of `tf.Tensor` (one for the output of the embeddings + one for the output of each layer) of shape
  1352. `(batch_size, sequence_length, hidden_size)`.
  1353. Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
  1354. encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  1355. Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  1356. sequence_length)`.
  1357. Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
  1358. self-attention heads.
  1359. encoder_global_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  1360. Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, x)`, where `x`
  1361. is the number of tokens with global attention mask.
  1362. Global attentions weights after the attention softmax, used to compute the weighted average in the
  1363. self-attention heads. Those are the attention weights from every token with global attention to every token
  1364. in the sequence.
  1365. """
  1366. loss: tf.Tensor | None = None
  1367. logits: tf.Tensor = None
  1368. past_key_values: List[tf.Tensor] | None = None
  1369. decoder_hidden_states: Tuple[tf.Tensor, ...] | None = None
  1370. decoder_attentions: Tuple[tf.Tensor, ...] | None = None
  1371. cross_attentions: Tuple[tf.Tensor, ...] | None = None
  1372. encoder_last_hidden_state: tf.Tensor | None = None
  1373. encoder_hidden_states: Tuple[tf.Tensor, ...] | None = None
  1374. encoder_attentions: Tuple[tf.Tensor, ...] | None = None
  1375. encoder_global_attentions: Tuple[tf.Tensor, ...] | None = None
# Boilerplate documentation injected into the model classes' docstrings
# (via the library's add_start_docstrings decorators); runtime text, do not edit casually.
LED_START_DOCSTRING = r"""
    This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
    as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
    behavior.

    <Tip>

    TensorFlow models and layers in `transformers` accept two formats as input:

    - having all inputs as keyword arguments (like PyTorch models), or
    - having all inputs as a list, tuple or dict in the first positional argument.

    The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
    and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
    pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
    format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
    the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
    positional argument:

    - a single Tensor with `input_ids` only and nothing else: `model(input_ids)`
    - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
      `model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
    - a dictionary with one or several input Tensors associated to the input names given in the docstring:
      `model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Note that when creating models and layers with
    [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
    about any of this, as you can just pass inputs like you would to any other Python function!

    </Tip>

    Args:
        config ([`LEDConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~TFPreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Shared description of the model-call arguments, formatted with the input shape
# placeholder `{0}` and injected into `call` docstrings; runtime text, do not edit casually.
LED_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`tf.Tensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`tf.Tensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`LedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            LED uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
        decoder_attention_mask (`tf.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
        head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the decoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tf.Tensor`, *optional*):
            hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
            of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
        past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`)
            contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*, defaults to `True`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`). Set to `False` during training, `True` during generation
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
            config will be used instead.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
            used instead.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
            eager mode, in graph mode the value will always be set to True.
        training (`bool`, *optional*, defaults to `False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""
  1462. @keras_serializable
  1463. class TFLEDEncoder(keras.layers.Layer):
  1464. config_class = LEDConfig
  1465. """
  1466. Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is a
  1467. [`TFLEDEncoderLayer`].
  1468. Args:
  1469. config: LEDConfig
  1470. """
  1471. def __init__(self, config: LEDConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
  1472. super().__init__(**kwargs)
  1473. self.config = config
  1474. self.dropout = keras.layers.Dropout(config.dropout)
  1475. if config.encoder_layerdrop > 0:
  1476. logger.warning("Layerdrop is currently disabled in TFLED models.")
  1477. self.layerdrop = 0.0
  1478. self.padding_idx = config.pad_token_id
  1479. if isinstance(config.attention_window, int):
  1480. assert config.attention_window % 2 == 0, "`config.attention_window` has to be an even value"
  1481. assert config.attention_window > 0, "`config.attention_window` has to be positive"
  1482. config.attention_window = [config.attention_window] * config.num_hidden_layers # one value per layer
  1483. else:
  1484. assert len(config.attention_window) == config.num_hidden_layers, (
  1485. "`len(config.attention_window)` should equal `config.num_hidden_layers`. "
  1486. f"Expected {config.num_hidden_layers}, given {len(config.attention_window)}"
  1487. )
  1488. self.attention_window = config.attention_window
  1489. self.embed_tokens = embed_tokens
  1490. self.embed_positions = TFLEDLearnedPositionalEmbedding(
  1491. config.max_encoder_position_embeddings,
  1492. config.d_model,
  1493. name="embed_positions",
  1494. )
  1495. self.layers = [TFLEDEncoderLayer(config, i, name=f"layers.{i}") for i in range(config.encoder_layers)]
  1496. self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
  1497. self.embed_dim = config.d_model
  1498. def get_embed_tokens(self):
  1499. return self.embed_tokens
  1500. def set_embed_tokens(self, embed_tokens):
  1501. self.embed_tokens = embed_tokens
  1502. @unpack_inputs
  1503. def call(
  1504. self,
  1505. input_ids=None,
  1506. inputs_embeds=None,
  1507. attention_mask=None,
  1508. global_attention_mask=None,
  1509. head_mask=None,
  1510. output_attentions=None,
  1511. output_hidden_states=None,
  1512. return_dict=None,
  1513. training=False,
  1514. ):
  1515. """
  1516. Args:
  1517. input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
  1518. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  1519. provide it.
  1520. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1521. [`PreTrainedTokenizer.__call__`] for details.
  1522. [What are input IDs?](../glossary#input-ids)
  1523. attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1524. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1525. - 1 for tokens that are **not masked**,
  1526. - 0 for tokens that are **masked**.
  1527. [What are attention masks?](../glossary#attention-mask)
  1528. head_mask (`tf.Tensor` of shape `(num_layers, num_heads)`, *optional*):
  1529. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  1530. - 1 indicates the head is **not masked**,
  1531. - 0 indicates the head is **masked**.
  1532. inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1533. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  1534. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  1535. than the model's internal embedding lookup matrix.
  1536. output_attentions (`bool`, *optional*):
  1537. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1538. returned tensors for more detail.
  1539. output_hidden_states (`bool`, *optional*):
  1540. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  1541. for more detail.
  1542. return_dict (`bool`, *optional*):
  1543. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1544. """
  1545. if input_ids is not None and inputs_embeds is not None:
  1546. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  1547. elif input_ids is not None:
  1548. input_shape = shape_list(input_ids)
  1549. check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
  1550. inputs_embeds = self.embed_tokens(input_ids)
  1551. elif inputs_embeds is not None:
  1552. input_shape = shape_list(inputs_embeds)[:-1]
  1553. else:
  1554. raise ValueError("You have to specify either input_ids or inputs_embeds")
  1555. if attention_mask is None:
  1556. attention_mask = tf.fill(input_shape, 1)
  1557. # merge `global_attention_mask` and `attention_mask`
  1558. if global_attention_mask is not None:
  1559. attention_mask = attention_mask * tf.cast((global_attention_mask + 1), dtype=attention_mask.dtype)
  1560. padding_len, input_ids, attention_mask, inputs_embeds = self._pad_to_window_size(
  1561. input_ids=input_ids,
  1562. attention_mask=attention_mask,
  1563. inputs_embeds=inputs_embeds,
  1564. pad_token_id=self.padding_idx,
  1565. )
  1566. input_shape = shape_list(attention_mask)
  1567. # is index masked or global attention
  1568. is_index_masked = tf.math.less(tf.cast(attention_mask, tf.int8), 1)
  1569. is_index_global_attn = tf.math.greater(tf.cast(attention_mask, tf.int8), 1)
  1570. is_global_attn = tf.math.reduce_any(is_index_global_attn)
  1571. embed_pos = self.embed_positions(input_shape)
  1572. hidden_states = inputs_embeds + embed_pos
  1573. hidden_states = self.layernorm_embedding(hidden_states)
  1574. hidden_states = self.dropout(hidden_states, training=training)
  1575. # check attention mask and invert
  1576. if attention_mask is not None:
  1577. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  1578. attention_mask = _expand_mask(attention_mask)[:, 0, 0, :]
  1579. attention_mask = attention_mask[:, :, None, None]
  1580. encoder_states = () if output_hidden_states else None
  1581. all_attentions = all_global_attentions = () if output_attentions else None
  1582. # check if head_mask has a correct number of layers specified if desired
  1583. if head_mask is not None:
  1584. tf.debugging.assert_equal(
  1585. shape_list(head_mask)[0],
  1586. len(self.layers),
  1587. message=(
  1588. f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
  1589. f" {shape_list(head_mask)[0]}."
  1590. ),
  1591. )
  1592. # encoder layers
  1593. for idx, encoder_layer in enumerate(self.layers):
  1594. if output_hidden_states:
  1595. hidden_states_to_add = self.compute_hidden_states(hidden_states, padding_len)
  1596. encoder_states = encoder_states + (hidden_states_to_add,)
  1597. # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
  1598. dropout_probability = random.uniform(0, 1)
  1599. if training and (dropout_probability < self.layerdrop): # skip the layer
  1600. continue
  1601. layer_outputs = encoder_layer(
  1602. hidden_states=hidden_states,
  1603. attention_mask=attention_mask,
  1604. layer_head_mask=head_mask[idx] if head_mask is not None else None,
  1605. is_index_masked=is_index_masked,
  1606. is_index_global_attn=is_index_global_attn,
  1607. is_global_attn=is_global_attn,
  1608. )
  1609. hidden_states = layer_outputs[0]
  1610. if output_attentions:
  1611. # bzs x seq_len x num_attn_heads x (num_global_attn + attention_window_len + 1) => bzs x num_attn_heads x seq_len x (num_global_attn + attention_window_len + 1)
  1612. all_attentions = all_attentions + (tf.transpose(layer_outputs[1], (0, 2, 1, 3)),)
  1613. # bzs x num_attn_heads x num_global_attn x seq_len => bzs x num_attn_heads x seq_len x num_global_attn
  1614. all_global_attentions = all_global_attentions + (tf.transpose(layer_outputs[2], (0, 1, 3, 2)),)
  1615. # undo padding
  1616. # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
  1617. hidden_states = self.compute_hidden_states(hidden_states, padding_len)
  1618. # undo padding
  1619. if output_attentions:
  1620. all_attentions = (
  1621. tuple([state[:, :, :-padding_len, :] for state in all_attentions])
  1622. if padding_len > 0
  1623. else all_attentions
  1624. )
  1625. if output_hidden_states:
  1626. encoder_states = encoder_states + (hidden_states,)
  1627. if not return_dict:
  1628. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  1629. return TFLEDEncoderBaseModelOutput(
  1630. last_hidden_state=hidden_states,
  1631. hidden_states=encoder_states,
  1632. attentions=all_attentions,
  1633. global_attentions=all_global_attentions,
  1634. )
  1635. @tf.function
  1636. def compute_hidden_states(self, hidden_states, padding_len):
  1637. return hidden_states[:, :-padding_len] if padding_len > 0 else hidden_states
  1638. def _pad_to_window_size(
  1639. self,
  1640. input_ids,
  1641. attention_mask,
  1642. inputs_embeds,
  1643. pad_token_id,
  1644. ):
  1645. """A helper function to pad tokens and mask to work with implementation of Longformer selfattention."""
  1646. # padding
  1647. attention_window = (
  1648. self.attention_window if isinstance(self.attention_window, int) else max(self.attention_window)
  1649. )
  1650. assert attention_window % 2 == 0, f"`attention_window` should be an even value. Given {attention_window}"
  1651. input_shape = shape_list(input_ids) if input_ids is not None else shape_list(inputs_embeds)
  1652. batch_size, seq_len = input_shape[:2]
  1653. padding_len = (attention_window - seq_len % attention_window) % attention_window
  1654. if padding_len > 0:
  1655. logger.warning_once(
  1656. f"Input ids are automatically padded from {seq_len} to {seq_len + padding_len} to be a multiple of "
  1657. f"`config.attention_window`: {attention_window}"
  1658. )
  1659. paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])
  1660. if input_ids is not None:
  1661. input_ids = tf.pad(input_ids, paddings, constant_values=pad_token_id)
  1662. if inputs_embeds is not None:
  1663. if padding_len > 0:
  1664. input_ids_padding = tf.fill((batch_size, padding_len), pad_token_id)
  1665. inputs_embeds_padding = self.embed_tokens(input_ids_padding)
  1666. inputs_embeds = tf.concat([inputs_embeds, inputs_embeds_padding], axis=-2)
  1667. attention_mask = tf.pad(attention_mask, paddings, constant_values=False) # no attention on the padding tokens
  1668. return (
  1669. padding_len,
  1670. input_ids,
  1671. attention_mask,
  1672. inputs_embeds,
  1673. )
  1674. def build(self, input_shape=None):
  1675. if self.built:
  1676. return
  1677. self.built = True
  1678. if getattr(self, "embed_positions", None) is not None:
  1679. with tf.name_scope(self.embed_positions.name):
  1680. self.embed_positions.build(None)
  1681. if getattr(self, "layernorm_embedding", None) is not None:
  1682. with tf.name_scope(self.layernorm_embedding.name):
  1683. self.layernorm_embedding.build([None, None, self.embed_dim])
  1684. if getattr(self, "layers", None) is not None:
  1685. for layer in self.layers:
  1686. with tf.name_scope(layer.name):
  1687. layer.build(None)
@keras_serializable
class TFLEDDecoder(keras.layers.Layer):
    config_class = LEDConfig
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`TFLEDDecoderLayer`]

    Args:
        config: LEDConfig
        embed_tokens: output embedding
    """

    def __init__(self, config: LEDConfig, embed_tokens: Optional[keras.layers.Embedding] = None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.padding_idx = config.pad_token_id
        self.embed_tokens = embed_tokens
        # Layerdrop is intentionally forced off: warn if the config requests it, then pin to 0.0.
        if config.decoder_layerdrop > 0:
            logger.warning("Layerdrop is currently disabled in TFLED models.")
        self.layerdrop = 0.0
        self.embed_positions = TFLEDLearnedPositionalEmbedding(
            config.max_decoder_position_embeddings,
            config.d_model,
            name="embed_positions",
        )
        self.layers = [TFLEDDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
        self.layernorm_embedding = keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm_embedding")
        self.dropout = keras.layers.Dropout(config.dropout)

    def set_embed_tokens(self, embed_tokens):
        # Allows the parent main layer to re-tie the shared embedding after `set_input_embeddings`.
        self.embed_tokens = embed_tokens

    @unpack_inputs
    def call(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        head_mask=None,
        encoder_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        r"""
        Args:
            input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
            attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`tf.Tensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
                of the decoder.
            encoder_attention_mask (`tf.Tensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`tf.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            encoder_head_mask (`tf.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
                on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
                Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
                decoding. If `past_key_values` are used, the user can optionally input only the last
                `decoder_input_ids` (those that don't have their past key value states given to this model) of shape
                `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        # Length of the cached prefix, read off the first layer's cached key tensor (dim 2 = seq).
        past_key_values_length = shape_list(past_key_values[0][0])[2] if past_key_values is not None else 0

        # embed positions
        positions = self.embed_positions(input_shape, past_key_values_length)

        if inputs_embeds is None:
            check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        # For a single-token step (incremental decoding) there is nothing to mask causally,
        # so an all-ones expanded mask is used instead of a causal one.
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
        else:
            combined_attention_mask = _expand_mask(
                tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
            )

        # Fold the user-provided padding mask into the causal mask (additive combination).
        if attention_mask is not None and input_shape[-1] > 1:
            combined_attention_mask = combined_attention_mask + _expand_mask(attention_mask, tgt_len=input_shape[-1])

        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            encoder_attention_mask = _expand_mask(encoder_attention_mask, tgt_len=input_shape[-1])

        hidden_states = self.layernorm_embedding(hidden_states + positions)
        hidden_states = self.dropout(hidden_states, training=training)

        # decoder layers
        all_hidden_states = ()
        all_self_attns = ()
        all_cross_attentions = ()
        present_key_values = ()

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            tf.debugging.assert_equal(
                shape_list(head_mask)[0],
                len(self.layers),
                message=(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {shape_list(head_mask)[0]}."
                ),
            )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                # Record the state *entering* each layer; the final state is appended after the loop.
                all_hidden_states += (hidden_states,)
            dropout_probability = random.uniform(0, 1)

            # NOTE: `self.layerdrop` is pinned to 0.0 in __init__, so this skip never fires.
            if training and (dropout_probability < self.layerdrop):
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                hidden_states,
                attention_mask=combined_attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                layer_head_mask=head_mask[idx] if head_mask is not None else None,
                encoder_layer_head_mask=encoder_head_mask[idx] if encoder_head_mask is not None else None,
                past_key_value=past_key_value,
            )

            if use_cache:
                present_key_values += (present_key_value,)

            if output_attentions:
                all_self_attns += (layer_self_attn,)
                all_cross_attentions += (layer_cross_attn,)

        # Append the final hidden state, or drop the collection entirely when not requested.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        else:
            all_hidden_states = None

        all_self_attns = all_self_attns if output_attentions else None
        all_cross_attentions = all_cross_attentions if output_attentions else None

        present_key_values = present_key_values if use_cache else None

        if not return_dict:
            # Tuple output omits the `None` entries, mirroring other TF seq2seq decoders in this file.
            return tuple(
                v
                for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
                if v is not None
            )
        else:
            return TFBaseModelOutputWithPastAndCrossAttentions(
                last_hidden_state=hidden_states,
                past_key_values=present_key_values,
                hidden_states=all_hidden_states,
                attentions=all_self_attns,
                cross_attentions=all_cross_attentions,
            )

    def build(self, input_shape=None):
        # Build sub-layers inside their own name scopes so weight names match checkpoints.
        if self.built:
            return
        self.built = True
        if getattr(self, "embed_positions", None) is not None:
            with tf.name_scope(self.embed_positions.name):
                self.embed_positions.build(None)
        if getattr(self, "layernorm_embedding", None) is not None:
            with tf.name_scope(self.layernorm_embedding.name):
                self.layernorm_embedding.build([None, None, self.config.d_model])
        if getattr(self, "layers", None) is not None:
            for layer in self.layers:
                with tf.name_scope(layer.name):
                    layer.build(None)
@keras_serializable
class TFLEDMainLayer(keras.layers.Layer):
    """Bare LED encoder-decoder stack with a shared (tied) token embedding, no task head."""

    config_class = LEDConfig

    def __init__(self, config: LEDConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.shared = keras.layers.Embedding(
            input_dim=config.vocab_size,
            output_dim=config.d_model,
            embeddings_initializer=keras.initializers.TruncatedNormal(stddev=self.config.init_std),
            name="led.shared",
        )
        # Additional attribute to specify the expected name scope of the layer (for loading/storing weights)
        self.shared.load_weight_prefix = "led.shared"

        # Encoder and decoder both receive the same embedding layer (weight tying).
        self.encoder = TFLEDEncoder(config, self.shared, name="encoder")
        self.decoder = TFLEDDecoder(config, self.shared, name="decoder")

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        # Re-tie the replacement embedding on both sides of the model.
        self.shared = new_embeddings
        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared

    @unpack_inputs
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs: Optional[Union[Tuple, TFLEDEncoderBaseModelOutput]] = None,
        global_attention_mask=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
        **kwargs,
    ):
        # Without any decoder input there is nothing to cache for incremental decoding.
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            use_cache = False

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                global_attention_mask=global_attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                training=training,
            )
        # If the user passed a tuple for encoder_outputs, we wrap it in a TFLEDEncoderBaseModelOutput when return_dict=True
        # NOTE(review): only the first three tuple entries are recovered here, so `global_attentions`
        # stays None in this path — presumably because the encoder's own tuple output omits it; confirm
        # before relying on `encoder_global_attentions` with a user-supplied tuple.
        elif return_dict and not isinstance(encoder_outputs, TFLEDEncoderBaseModelOutput):
            encoder_outputs = TFLEDEncoderBaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )
        # If the user passed a TFLEDEncoderBaseModelOutput for encoder_outputs, we wrap it in a tuple when return_dict=False
        elif not return_dict and not isinstance(encoder_outputs, tuple):
            encoder_outputs = encoder_outputs.to_tuple()

        decoder_outputs = self.decoder(
            decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            encoder_head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        if not return_dict:
            # Plain-tuple path: decoder fields first, then encoder fields.
            return decoder_outputs + encoder_outputs

        return TFLEDSeq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            encoder_global_attentions=encoder_outputs.global_attentions,
        )

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        # The shared/tied weights expect to be in the model base namespace
        # Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
        # the current one.
        with tf.name_scope(self.shared.load_weight_prefix + "/" + self.shared.name + "/"):
            self.shared.build(None)
        if getattr(self, "encoder", None) is not None:
            with tf.name_scope(self.encoder.name):
                self.encoder.build(None)
        if getattr(self, "decoder", None) is not None:
            with tf.name_scope(self.decoder.name):
                self.decoder.build(None)
  1990. @add_start_docstrings(
  1991. "The bare LED Model outputting raw hidden-states without any specific head on top.",
  1992. LED_START_DOCSTRING,
  1993. )
  1994. class TFLEDModel(TFLEDPreTrainedModel):
  1995. def __init__(self, config, *inputs, **kwargs):
  1996. super().__init__(config, *inputs, **kwargs)
  1997. self.led = TFLEDMainLayer(config, name="led")
  1998. def get_encoder(self):
  1999. return self.led.encoder
  2000. def get_decoder(self):
  2001. return self.led.decoder
  2002. @unpack_inputs
  2003. @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  2004. @add_code_sample_docstrings(
  2005. checkpoint=_CHECKPOINT_FOR_DOC,
  2006. output_type=TFLEDSeq2SeqModelOutput,
  2007. config_class=_CONFIG_FOR_DOC,
  2008. )
  2009. def call(
  2010. self,
  2011. input_ids: TFModelInputType | None = None,
  2012. attention_mask: tf.Tensor | None = None,
  2013. decoder_input_ids: tf.Tensor | None = None,
  2014. decoder_attention_mask: tf.Tensor | None = None,
  2015. head_mask: tf.Tensor | None = None,
  2016. decoder_head_mask: tf.Tensor | None = None,
  2017. encoder_outputs: tf.Tensor | None = None,
  2018. global_attention_mask: tf.Tensor | None = None,
  2019. past_key_values: Tuple[Tuple[tf.Tensor]] | None = None,
  2020. inputs_embeds: tf.Tensor | None = None,
  2021. decoder_inputs_embeds: tf.Tensor | None = None,
  2022. use_cache: bool | None = None,
  2023. output_attentions: bool | None = None,
  2024. output_hidden_states: bool | None = None,
  2025. return_dict: bool | None = None,
  2026. training: bool = False,
  2027. **kwargs,
  2028. ) -> Tuple[tf.Tensor] | TFLEDSeq2SeqModelOutput:
  2029. outputs = self.led(
  2030. input_ids=input_ids,
  2031. attention_mask=attention_mask,
  2032. decoder_input_ids=decoder_input_ids,
  2033. decoder_attention_mask=decoder_attention_mask,
  2034. encoder_outputs=encoder_outputs,
  2035. global_attention_mask=global_attention_mask,
  2036. head_mask=head_mask,
  2037. decoder_head_mask=decoder_head_mask,
  2038. past_key_values=past_key_values,
  2039. inputs_embeds=inputs_embeds,
  2040. decoder_inputs_embeds=decoder_inputs_embeds,
  2041. use_cache=use_cache,
  2042. output_attentions=output_attentions,
  2043. output_hidden_states=output_hidden_states,
  2044. return_dict=return_dict,
  2045. training=training,
  2046. )
  2047. return outputs
  2048. def serving_output(self, output):
  2049. pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
  2050. dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
  2051. dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
  2052. cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
  2053. enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
  2054. enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
  2055. enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None
  2056. return TFLEDSeq2SeqModelOutput(
  2057. last_hidden_state=output.last_hidden_state,
  2058. past_key_values=pkv,
  2059. decoder_hidden_states=dec_hs,
  2060. decoder_attentions=dec_attns,
  2061. cross_attentions=cross_attns,
  2062. encoder_last_hidden_state=output.encoder_last_hidden_state,
  2063. encoder_hidden_states=enc_hs,
  2064. encoder_attentions=enc_attns,
  2065. encoder_global_attentions=enc_g_attns,
  2066. )
  2067. def build(self, input_shape=None):
  2068. if self.built:
  2069. return
  2070. self.built = True
  2071. if getattr(self, "led", None) is not None:
  2072. with tf.name_scope(self.led.name):
  2073. self.led.build(None)
# Copied from transformers.models.bart.modeling_tf_bart.BiasLayer
class BiasLayer(keras.layers.Layer):
    """
    Bias as a layer. It is used for serialization purposes: `keras.Model.save_weights` stores on a per-layer basis,
    so all weights have to be registered in a layer.
    """

    def __init__(self, shape, initializer, trainable, name, **kwargs):
        super().__init__(name=name, **kwargs)
        # Note: the name of this variable will NOT be scoped when serialized, i.e. it will not be in the format of
        # "outer_layer/inner_layer/.../name:0". Instead, it will be "name:0". For further details, see:
        # https://github.com/huggingface/transformers/pull/18833#issuecomment-1233090214
        self.bias = self.add_weight(name=name, shape=shape, initializer=initializer, trainable=trainable)

    def call(self, x):
        # Broadcasted add of the registered bias onto the input.
        return x + self.bias
@add_start_docstrings(
    "The LED Model with a language modeling head. Can be used for summarization.",
    LED_START_DOCSTRING,
)
class TFLEDForConditionalGeneration(TFLEDPreTrainedModel):
    # The encoder/decoder embeddings are tied to `led.shared`, so standalone embedding weights in a
    # checkpoint are expected leftovers and safe to ignore on load.
    _keys_to_ignore_on_load_unexpected = [
        r"led.encoder.embed_tokens.weight",
        r"led.decoder.embed_tokens.weight",
    ]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.led = TFLEDMainLayer(config, name="led")
        self.use_cache = config.use_cache
        # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, config.vocab_size], initializer="zeros", trainable=False
        )

        # TODO (Joao): investigate why LED has numerical issues in XLA generate
        self.supports_xla_generation = False

    def get_decoder(self):
        return self.led.decoder

    def get_encoder(self):
        return self.led.encoder

    def get_bias(self):
        return {"final_logits_bias": self.bias_layer.bias}

    def set_bias(self, value):
        # Replaces the existing layers containing bias for correct (de)serialization.
        vocab_size = value["final_logits_bias"].shape[-1]
        self.bias_layer = BiasLayer(
            name="final_logits_bias", shape=[1, vocab_size], initializer="zeros", trainable=False
        )
        self.bias_layer.bias.assign(value["final_logits_bias"])

    def get_output_embeddings(self):
        # Output projection is tied to the input embedding table.
        return self.get_input_embeddings()

    def set_output_embeddings(self, value):
        self.set_input_embeddings(value)

    @unpack_inputs
    @add_start_docstrings_to_model_forward(LED_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFLEDSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids: TFModelInputType | None = None,
        attention_mask: np.ndarray | tf.Tensor | None = None,
        decoder_input_ids: np.ndarray | tf.Tensor | None = None,
        decoder_attention_mask: np.ndarray | tf.Tensor | None = None,
        head_mask: np.ndarray | tf.Tensor | None = None,
        decoder_head_mask: np.ndarray | tf.Tensor | None = None,
        encoder_outputs: TFLEDEncoderBaseModelOutput | None = None,
        global_attention_mask: np.ndarray | tf.Tensor | None = None,
        past_key_values: Tuple[Tuple[Union[np.ndarray, tf.Tensor]]] | None = None,
        inputs_embeds: np.ndarray | tf.Tensor | None = None,
        decoder_inputs_embeds: np.ndarray | tf.Tensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        labels: tf.Tensor | None = None,
        training: bool = False,
    ) -> Tuple[tf.Tensor] | TFLEDSeq2SeqLMOutput:
        """
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, TFLEDForConditionalGeneration
        >>> import tensorflow as tf

        >>> mname = "allenai/led-base-16384"
        >>> tokenizer = AutoTokenizer.from_pretrained(mname)
        >>> TXT = "My friends are <mask> but they eat too many carbs."
        >>> model = TFLEDForConditionalGeneration.from_pretrained(mname)
        >>> batch = tokenizer([TXT], return_tensors="tf")
        >>> logits = model(inputs=batch.input_ids).logits
        >>> probs = tf.nn.softmax(logits[0])
        >>> # probs[5] is associated with the mask token
        ```"""

        if labels is not None:
            # Training with labels: caching is pointless, and if no decoder inputs were given,
            # derive them from the labels by shifting right (teacher forcing).
            use_cache = False
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.led(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_outputs,
            global_attention_mask=global_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Project decoder states onto the vocabulary with the shared (tied) embedding weights,
        # then add the non-trainable final logits bias.
        lm_logits = tf.matmul(outputs[0], self.led.shared.weights, transpose_b=True)
        lm_logits = self.bias_layer(lm_logits)
        masked_lm_loss = None if labels is None else self.hf_compute_loss(labels, lm_logits)

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        return TFLEDSeq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,  # index 1 of d outputs
            decoder_hidden_states=outputs.decoder_hidden_states,  # index 2 of d outputs
            decoder_attentions=outputs.decoder_attentions,  # index 3 of d outputs
            cross_attentions=outputs.cross_attentions,  # index 4 of d outputs
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,  # index 0 of encoder outputs
            encoder_hidden_states=outputs.encoder_hidden_states,  # 1 of e out
            encoder_attentions=outputs.encoder_attentions,  # 2 of e out
            encoder_global_attentions=outputs.encoder_global_attentions,
        )

    def serving_output(self, output):
        # Convert tuple-valued fields to tensors for the serving signature, gated on config flags.
        pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
        dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None
        dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None
        cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None
        enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None
        enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None
        enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None

        return TFLEDSeq2SeqLMOutput(
            logits=output.logits,
            past_key_values=pkv,
            decoder_hidden_states=dec_hs,
            decoder_attentions=dec_attns,
            cross_attentions=cross_attns,
            encoder_last_hidden_state=output.encoder_last_hidden_state,
            encoder_hidden_states=enc_hs,
            encoder_attentions=enc_attns,
            encoder_global_attentions=enc_g_attns,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past is used
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_decoder_input_ids_from_labels(self, labels: tf.Tensor):
        return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

    def hf_compute_loss(self, labels, logits):
        """CrossEntropyLoss that ignores pad tokens"""
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=keras.losses.Reduction.NONE)
        if self.config.tf_legacy_loss:
            # Legacy path: flatten, drop padded positions with boolean_mask, then average.
            melted_labels = tf.reshape(labels, (-1,))
            active_loss = tf.not_equal(melted_labels, self.config.pad_token_id)
            reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
            labels = tf.boolean_mask(melted_labels, active_loss)
            return loss_fn(labels, reduced_logits)

        # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
        unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
        # make sure only non-padding labels affect the loss
        loss_mask = tf.cast(labels != self.config.pad_token_id, dtype=unmasked_loss.dtype)
        masked_loss = unmasked_loss * loss_mask
        reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
        return tf.reshape(reduced_masked_loss, (1,))

    def build(self, input_shape=None):
        if self.built:
            return
        self.built = True
        if getattr(self, "led", None) is not None:
            with tf.name_scope(self.led.name):
                self.led.build(None)
        if getattr(self, "bias_layer", None) is not None:
            with tf.name_scope(self.bias_layer.name):
                self.bias_layer.build(None)