
# coding=utf-8
# Copyright 2022 LongT5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax LongT5 model."""

import copy
from typing import Any, Callable, List, Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey

from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxSeq2SeqLMOutput,
    FlaxSeq2SeqModelOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_longt5 import LongT5Config


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "google/long-t5-local-base"
_CONFIG_FOR_DOC = "LongT5Config"

remat = nn_partitioning.remat


# Copied from transformers.models.bart.modeling_flax_bart.shift_tokens_right
def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = jnp.zeros_like(input_ids)
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)

    shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids


def _pad_to_multiple(x: jnp.ndarray, block_len: int, axis: int, pad_value: int = 0) -> jnp.ndarray:
    """Pad an array so that a sequence length will be a multiple of `block_len`."""
    pad_len = -x.shape[axis] % block_len
    pad = [(0, 0)] * x.ndim
    pad[axis] = (0, pad_len)
    x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value)
    return x


def _split_into_blocks(x: jnp.ndarray, block_len: int, axis: int) -> jnp.ndarray:
    """Split an input array into blocks of a given `block_len` along the given `axis`. If the dimension length
    is not a multiple of `block_len`, it will be padded first with the selected `pad_value`.
    """
    # pad tensor to multiple of block_len
    if x.shape[axis] % block_len != 0:
        x = _pad_to_multiple(x, block_len, axis, pad_value=0)
    num_blocks = x.shape[axis] // block_len
    output_shape = x.shape[:axis] + (num_blocks, block_len) + x.shape[(axis + 1) :]
    return x.reshape(output_shape)
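
# Illustrative sketch (editor's note, not part of the original module): shape behaviour of the two helpers
# above on a toy input. With seq_len=5 and block_len=3, `_pad_to_multiple` appends one zero-padded position
# and `_split_into_blocks` folds the sequence axis into (num_blocks, block_len):
#
#     x = jnp.zeros((1, 5, 4))                           # (batch, seq_len, dim)
#     _pad_to_multiple(x, block_len=3, axis=1).shape     # (1, 6, 4)
#     _split_into_blocks(x, block_len=3, axis=1).shape   # (1, 2, 3, 4) -> (batch, num_blocks, block_len, dim)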


def _concatenate_3_blocks(x: jnp.ndarray, block_axis: int, sequence_axis: int, pad_value: int = 0) -> jnp.ndarray:
    """Concatenate three consecutive blocks for each input block for local attention.

    For more information, see: https://arxiv.org/pdf/2112.07916.pdf.
    """
    num_blocks = x.shape[block_axis]

    pad = [(0, 0)] * x.ndim
    pad[block_axis] = (1, 1)
    # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len]
    x = jnp.pad(x, pad_width=pad, mode="constant", constant_values=pad_value)

    blocks_list: List[np.ndarray] = []
    for i in range(3):
        # We use an indexing approach here:
        # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs
        indices = [slice(0, None)] * x.ndim
        indices[block_axis] = slice(i, i + num_blocks)
        indices = tuple(indices)
        blocks_list.append(x[indices])
    return jnp.concatenate(blocks_list, axis=sequence_axis)  # [batch_size, num_blocks, 3 * block_len, ...]
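
# Illustrative sketch (editor's note): `_concatenate_3_blocks` pads one block of zeros on each side of the
# block axis and stacks, for every block, the previous/current/next block along the sequence axis, so each
# query block can attend to keys from its own block and both neighbours:
#
#     blocked = jnp.zeros((1, 2, 3, 4))                                    # (batch, num_blocks, block_len, dim)
#     _concatenate_3_blocks(blocked, block_axis=1, sequence_axis=2).shape  # (1, 2, 9, 4), i.e. 3 * block_len keys per block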


def _make_3block_relative_position_ids(block_len: int) -> jnp.ndarray:
    """Makes 3-blocked relative position ids for local attention."""
    position_ids = jnp.arange(3 * block_len, dtype=jnp.int32)
    center_position_ids = position_ids[block_len:-block_len]
    relative_position_ids = position_ids[None, :] - center_position_ids[:, None]  # [block_len, 3 * block_len]
    return relative_position_ids
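
# Illustrative sketch (editor's note): for block_len=2 the center positions are [2, 3] inside the 3-block
# window [0..5], so each row holds the signed offsets from one center token to every position in the window:
#
#     _make_3block_relative_position_ids(2)
#     # [[-2, -1,  0,  1,  2,  3],
#     #  [-3, -2, -1,  0,  1,  2]]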


def _mask_local_attention_mask(local_attention_mask: np.ndarray, block_len: int) -> jnp.ndarray:
    """Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than `local_radius`."""
    relative_position_ids = _make_3block_relative_position_ids(block_len)
    locality_mask = jnp.abs(relative_position_ids) < block_len
    locality_mask = locality_mask[None, None, :, :]
    return jnp.logical_and(local_attention_mask, locality_mask)


def _get_local_attention_mask(attention_mask: np.ndarray, block_len: int) -> jnp.ndarray:
    """Prepare attention mask to be applied for a local attention."""
    # [batch_size, num_blocks, block_len]
    _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, axis=1)
    # [batch_size, num_blocks, 3 * block_len]
    _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_axis=1, sequence_axis=2)

    _blocked_attention_mask = _blocked_attention_mask[..., None]
    _3blocked_attention_mask = _3blocked_attention_mask[..., None, :]
    # [batch_size, num_blocks, block_len, 3 * block_len]
    local_attention_mask = jnp.logical_and(_blocked_attention_mask, _3blocked_attention_mask)
    local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len)
    # [batch_size, 1, num_blocks, block_len, 3 * block_len]
    return local_attention_mask[:, None, ...]
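
# Editor's note (summary of the helpers above): starting from a padding mask of shape (batch_size, seq_len),
# `_get_local_attention_mask` returns a boolean mask of shape (batch_size, 1, num_blocks, block_len, 3 * block_len)
# that is True only where both the query and the key position are real tokens and their relative distance is
# strictly smaller than `block_len` (i.e. at most `local_radius`, since block_len = local_radius + 1).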


def _make_global_fixed_block_ids(attention_mask: np.ndarray, global_block_size: int) -> Tuple[jnp.ndarray, np.ndarray]:
    """Obtain the "fixed block" global id corresponding to each input token.

    This implementation is a simplified version of the original Flaxformer implementation adopted from:
    https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py.

    In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make up
    a whole fixed block, are assigned to the preceding block.

    Padding tokens from the original sequence are represented by -1.
    """
    batch_size, seq_len = attention_mask.shape[:2]

    def handle_orphan_tokens(block_ids: np.ndarray) -> jnp.ndarray:
        block_ends = (jnp.arange(seq_len) % global_block_size) == global_block_size - 1
        true_block_ends = jnp.logical_and(block_ends, block_ids >= 0)
        full_blocks = true_block_ends.sum(-1)[..., None]
        block_ids = jnp.minimum(block_ids, full_blocks - 1)
        return block_ids

    fixed_block_mask = jnp.ones_like(attention_mask) / global_block_size
    fixed_block_mask = jnp.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
    mask = jnp.where(attention_mask != 0.0, 1.0, -1000.0)
    global_block_ids = jnp.maximum(
        jnp.floor(mask + fixed_block_mask - 1.0), jnp.array(-1.0, dtype=attention_mask.dtype)
    )
    # set padding tokens to -1
    global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1)
    # [batch_size, seq_len]
    global_block_ids = handle_orphan_tokens(global_block_ids)
    num_globals = seq_len // global_block_size

    # [batch_size, seq_len // global_block_size]
    if num_globals > 0:
        _sequence_block_ids_max = jnp.repeat(global_block_ids.max(axis=-1)[:, None], repeats=num_globals, axis=1)
    else:
        _sequence_block_ids_max = jnp.zeros((batch_size, 0), dtype=global_block_ids.dtype)
    global_segment_ids = jnp.cumsum(jnp.ones((batch_size, num_globals)), axis=-1) - 1
    global_segment_ids = jnp.where(global_segment_ids <= _sequence_block_ids_max, 1, 0)
    return global_block_ids, global_segment_ids
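
# Illustrative sketch (editor's note): with global_block_size=2 and an input of 4 real tokens followed by
# 2 padding tokens,
#
#     mask = jnp.array([[1, 1, 1, 1, 0, 0]])
#     block_ids, segment_ids = _make_global_fixed_block_ids(mask, global_block_size=2)
#     # block_ids   -> [[0, 0, 1, 1, -1, -1]]   (padding positions are marked -1)
#     # segment_ids -> [[1, 1, 0]]              (the third global block contains no real tokens)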


def _make_side_relative_position_ids(attention_mask: np.ndarray, global_block_size: int) -> np.ndarray:
    """Create the relative position tensor for local -> global attention."""
    block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size)
    global_seq_len = global_segment_ids.shape[-1]
    global_positions = jnp.arange(global_seq_len)
    side_relative_position = global_positions - block_ids[..., None]
    return side_relative_position


def _create_global_aggregates(hidden_states: np.ndarray, block_ids: np.ndarray, global_seq_len: int) -> np.ndarray:
    """Compute individual block aggregates by summing over individual blocks."""
    # (batch..., seq_len, global_seq_len)
    one_hot_block_ids = jax.nn.one_hot(block_ids, global_seq_len)
    return jnp.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids)
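
# Illustrative sketch (editor's note): `_create_global_aggregates` sums the token representations that share
# the same (non-negative) block id; block id -1 (padding) falls outside the one-hot range and contributes nothing:
#
#     hidden = jnp.ones((1, 4, 8))                                         # (batch, seq_len, d_model)
#     block_ids = jnp.array([[0, 0, 1, -1]])
#     _create_global_aggregates(hidden, block_ids, global_seq_len=2).shape # (1, 2, 8)
#     # global block 0 sums two tokens, block 1 sums one token, the padding token is dropped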


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerNorm with T5->LongT5
class FlaxLongT5LayerNorm(nn.Module):
    hidden_size: int
    dtype: jnp.dtype = jnp.float32
    eps: float = 1e-6
    weight_init: Callable[..., np.ndarray] = jax.nn.initializers.ones

    def setup(self):
        self.weight = self.param("weight", self.weight_init, (self.hidden_size,))

    def __call__(self, hidden_states):
        """
        Construct a layernorm module in the LongT5 style; No bias and no subtraction of mean.
        """
        # layer norm should always be calculated in float32
        variance = jnp.power(hidden_states.astype("f4"), 2).mean(axis=-1, keepdims=True)
        hidden_states = hidden_states / jnp.sqrt(variance + self.eps)

        return self.weight * hidden_states
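
# Editor's note: FlaxLongT5LayerNorm is an RMSNorm-style layer (T5 convention), roughly
#     weight * x / sqrt(mean(x.astype(float32) ** 2, axis=-1, keepdims=True) + eps)
# i.e. no mean subtraction and no bias term, with the mean of squares computed in float32.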


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseActDense with T5->LongT5
class FlaxLongT5DenseActDense(nn.Module):
    config: LongT5Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
        wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)

        self.wi = nn.Dense(
            self.config.d_ff,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(wi_init_std),
            dtype=self.dtype,
        )
        self.wo = nn.Dense(
            self.config.d_model,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(wo_init_std),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(self.config.dropout_rate)
        self.act = ACT2FN[self.config.dense_act_fn]

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.wo(hidden_states)
        return hidden_states


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5DenseGatedActDense with T5->LongT5
class FlaxLongT5DenseGatedActDense(nn.Module):
    config: LongT5Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        wi_init_std = self.config.initializer_factor * (self.config.d_model**-0.5)
        wo_init_std = self.config.initializer_factor * (self.config.d_ff**-0.5)

        self.wi_0 = nn.Dense(
            self.config.d_ff,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(wi_init_std),
            dtype=self.dtype,
        )
        self.wi_1 = nn.Dense(
            self.config.d_ff,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(wi_init_std),
            dtype=self.dtype,
        )
        self.wo = nn.Dense(
            self.config.d_model,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(wo_init_std),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(self.config.dropout_rate)
        self.act = ACT2FN[self.config.dense_act_fn]

    def __call__(self, hidden_states, deterministic):
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.wo(hidden_states)
        return hidden_states


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerFF with T5->LongT5
class FlaxLongT5LayerFF(nn.Module):
    config: LongT5Config
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        if self.config.is_gated_act:
            self.DenseReluDense = FlaxLongT5DenseGatedActDense(self.config, dtype=self.dtype)
        else:
            self.DenseReluDense = FlaxLongT5DenseActDense(self.config, dtype=self.dtype)

        self.layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
        )
        self.dropout = nn.Dropout(self.config.dropout_rate)

    def __call__(self, hidden_states, deterministic=True):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states, deterministic=deterministic)
        hidden_states = hidden_states + self.dropout(forwarded_states, deterministic=deterministic)
        return hidden_states
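
# Editor's note: `FlaxLongT5LayerFF` selects between the two feed-forward variants above via
# `config.is_gated_act`: the gated version (two input projections whose activated and linear branches are
# multiplied elementwise, as in T5 v1.1) or the plain DenseActDense feed-forward of the original T5.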


# Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention with T5->LongT5
class FlaxLongT5Attention(nn.Module):
    config: LongT5Config
    has_relative_attention_bias: bool = False
    causal: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
        o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)

        self.q = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(q_init_std),
            dtype=self.dtype,
        )
        self.k = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        self.v = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        self.o = nn.Dense(
            self.d_model,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(o_init_std),
            dtype=self.dtype,
        )

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
                dtype=self.dtype,
            )

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            relative_position = -jnp.clip(relative_position, a_max=0)
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)

        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

        return relative_buckets.astype("i4")
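
    # Illustrative sketch (editor's note): with the defaults (bidirectional=True, num_buckets=32,
    # max_distance=128) half of the buckets encode the sign of the offset, small offsets get exact buckets,
    # and larger ones are binned logarithmically, e.g.
    #
    #     FlaxLongT5Attention._relative_position_bucket(jnp.array([-1, 0, 1, 2, 50]))
    #     # -> [1, 0, 17, 18, 29]   (negative offsets land in [0, 16), positive offsets are shifted by 16)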

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias"""
        context_position = jnp.arange(query_length, dtype="i4")[:, None]
        memory_position = jnp.arange(key_length, dtype="i4")[None, :]

        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=(not self.causal),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )

        values = self.relative_attention_bias(relative_position_bucket)
        values = values.transpose((2, 0, 1))[None, :, :, :]
        return values

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.inner_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = jax.lax.dynamic_update_slice(cached_key.value, key, indices)
            value = jax.lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those
            # key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask
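
    # Editor's note: the cache above is pre-allocated to the maximum decoder length when it is initialized;
    # each decoding step writes its freshly projected key/value slice at `cache_index` via
    # `jax.lax.dynamic_update_slice` and narrows the attention mask so a query only attends to positions
    # that have already been written.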

    def _create_position_bias(
        self, key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
    ):
        cache_is_filled = self.causal and self.has_variable("cache", "cached_key") and (not init_cache)
        key_length = key_states.shape[1]
        query_length = key_length if cache_is_filled else query_states.shape[1]

        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(query_length, key_length)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros((1, self.n_heads, query_length, key_length), dtype=self.dtype)

        # if key and values are already calculated, only the last query position bias should be taken
        if cache_is_filled:
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            position_bias = jax.lax.dynamic_slice(
                position_bias,
                (0, 0, causal_attention_mask_shift, 0),
                (1, self.n_heads, seq_length, max_decoder_length),
            )
        return position_bias

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        use_cache=False,
        output_attentions=False,
        deterministic=True,
        init_cache=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        batch_size, seq_length = hidden_states.shape[:2]

        # q, k, v projections
        query_states = self.q(hidden_states)  # (batch_size, seq_length, n_heads * dim_per_head)
        key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states)
        value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states)

        # reshape to (batch_size, seq_length, n_heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # counter-act scaling in dot_product_attention_weights function
        query_states *= jnp.sqrt(query_states.shape[-1])

        # for fast decoding causal attention mask should be shifted
        causal_attention_mask_shift = (
            self.variables["cache"]["cache_index"] if (self.has_variable("cache", "cached_key") and self.causal) else 0
        )
        # create causal attention_mask; attention_mask has to be defined when model is causal
        if self.causal:
            causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")

            # fast decoding for generate requires special attention_mask
            if self.has_variable("cache", "cached_key"):
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_attention_mask = jax.lax.dynamic_slice(
                    causal_attention_mask,
                    (0, 0, causal_attention_mask_shift, 0),
                    (1, 1, seq_length, max_decoder_length),
                )

            # broadcast causal attention mask & attention mask to fit for merge
            causal_attention_mask = jnp.broadcast_to(
                causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:]
            )
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape
            )
            attention_mask = combine_masks(attention_mask, causal_attention_mask)
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # replace masked positions with the smallest representable value of the computation dtype
        if attention_mask is not None:
            mask_value = jnp.finfo(self.dtype).min
            attention_mask = jax.lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
            )

        if position_bias is None:
            # compute position bias (only for first layer)
            position_bias = self._create_position_bias(
                key_states, query_states, attention_mask, init_cache, seq_length, causal_attention_mask_shift
            )

            if attention_mask is not None:
                position_bias = position_bias + attention_mask

        # create dropout rng
        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # Softmax(QK^T)
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=position_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
        )

        # multiply with value states
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)

        # bring back to (batch_size, seq_length, d_model)
        attn_output = self._merge_heads(attn_output)

        # apply output matrix
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)

        return outputs
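
# Illustrative usage sketch (editor's note; the tiny config values below are made up for the example):
#
#     config = LongT5Config(d_model=64, d_kv=8, num_heads=8)
#     attn = FlaxLongT5Attention(config, has_relative_attention_bias=True)
#     hidden = jnp.zeros((2, 10, 64))                       # (batch, seq_len, d_model)
#     variables = attn.init(jax.random.PRNGKey(0), hidden)
#     attn_output, position_bias = attn.apply(variables, hidden)
#     # attn_output: (2, 10, 64), position_bias: (1, num_heads, 10, 10)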


class FlaxLongT5LocalAttention(nn.Module):
    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.local_radius = self.config.local_radius
        self.block_len = self.local_radius + 1
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
        o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)

        self.q = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(q_init_std),
            dtype=self.dtype,
        )
        self.k = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        self.v = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        self.o = nn.Dense(
            self.d_model,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(o_init_std),
            dtype=self.dtype,
        )

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
            )

    @staticmethod
    # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            relative_position = -jnp.clip(relative_position, a_max=0)
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)

        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

        return relative_buckets.astype("i4")

    def compute_bias(self, block_length: int):
        """Compute binned relative position bias"""
        memory_position = jnp.arange(3 * block_length, dtype="i4")
        context_position = memory_position[block_length:-block_length]

        relative_position = memory_position[None, :] - context_position[:, None]
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )

        values = self.relative_attention_bias(relative_position_bucket)
        values = values.transpose((2, 0, 1))[None, None, :, :, :]
        return values

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim)

    def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray:
        # position_bias shape: (1, 1, n_heads, block_len, 3 * block_len)
        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(block_len)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype)

        return position_bias

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        batch_size, seq_length = hidden_states.shape[:2]

        # q, k, v projections
        query_states = self.q(hidden_states)  # (batch_size, seq_length, n_heads * dim_per_head)
        key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states)
        value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states)

        # reshape to (batch_size, seq_length, n_heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, head_dim)
        query_states = _split_into_blocks(query_states, self.block_len, axis=1)
        key_states = _split_into_blocks(key_states, self.block_len, axis=1)
        value_states = _split_into_blocks(value_states, self.block_len, axis=1)

        # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
        key_states = _concatenate_3_blocks(key_states, block_axis=1, sequence_axis=2)
        value_states = _concatenate_3_blocks(value_states, block_axis=1, sequence_axis=2)

        # counter-act scaling in dot_product_attention_weights function
        query_states *= jnp.sqrt(query_states.shape[-1])

        if attention_mask is not None:
            attention_mask = _get_local_attention_mask(attention_mask, self.block_len)

            # replace masked positions with a large negative bias
            attention_mask = jax.lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
            )

        if position_bias is None:
            # compute position bias (only for first layer)
            position_bias = self._create_position_bias(self.block_len, attention_mask)

            if attention_mask is not None:
                position_bias = position_bias + attention_mask.swapaxes(1, 2)

        # create dropout rng
        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # Softmax(QK^T)
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=position_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
        )

        # multiply with value states
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)

        # bring back to (batch_size, seq_length, d_model)
        attn_output = self._merge_heads(attn_output)
        attn_output = attn_output[:, :seq_length, :]

        # apply output matrix
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)

        return outputs
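
# Editor's note: in contrast to the quadratic full attention above, `FlaxLongT5LocalAttention` only forms
# (block_len x 3 * block_len) score matrices per block, so memory and compute grow linearly with the
# sequence length for a fixed `local_radius` (block_len = local_radius + 1).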


class FlaxLongT5TransientGlobalAttention(nn.Module):
    config: LongT5Config
    has_relative_attention_bias: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.relative_attention_num_buckets = self.config.relative_attention_num_buckets
        self.relative_attention_max_distance = self.config.relative_attention_max_distance
        self.d_model = self.config.d_model
        self.key_value_proj_dim = self.config.d_kv
        self.n_heads = self.config.num_heads
        self.local_radius = self.config.local_radius
        self.block_len = self.local_radius + 1
        self.global_block_size = self.config.global_block_size
        self.dropout = self.config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        q_init_std = self.config.initializer_factor * ((self.inner_dim * self.key_value_proj_dim) ** -0.5)
        kv_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)
        o_init_std = self.config.initializer_factor * (self.inner_dim**-0.5)

        self.q = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(q_init_std),
            dtype=self.dtype,
        )
        self.k = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        self.v = nn.Dense(
            self.inner_dim,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(kv_init_std),
            dtype=self.dtype,
        )
        self.o = nn.Dense(
            self.d_model,
            use_bias=False,
            kernel_init=jax.nn.initializers.normal(o_init_std),
            dtype=self.dtype,
        )

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
            )

        # Relative attention bias & layer norm for global attention
        if self.has_relative_attention_bias:
            self.global_relative_attention_bias = nn.Embed(
                self.relative_attention_num_buckets,
                self.n_heads,
                embedding_init=jax.nn.initializers.normal(kv_init_std),
            )
        self.global_input_layer_norm = FlaxLongT5LayerNorm(
            self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
        )

    @staticmethod
    # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Attention._relative_position_bucket
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0) * num_buckets
            relative_position = jnp.abs(relative_position)
        else:
            relative_position = -jnp.clip(relative_position, a_max=0)
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            jnp.log(relative_position / max_exact) / jnp.log(max_distance / max_exact) * (num_buckets - max_exact)
        )
        relative_position_if_large = jnp.clip(relative_position_if_large, a_max=num_buckets - 1)

        relative_buckets += jnp.where(is_small, relative_position, relative_position_if_large)

        return relative_buckets.astype("i4")

    def compute_bias(self, block_length: int):
        """Compute binned relative position bias"""
        memory_position = jnp.arange(3 * block_length, dtype="i4")
        context_position = memory_position[block_length:-block_length]

        relative_position = memory_position[None, :] - context_position[:, None]
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )

        values = self.relative_attention_bias(relative_position_bucket)
        values = values.transpose((2, 0, 1))[None, None, :, :, :]
        return values

    def compute_side_bias(self, attention_mask: np.ndarray, global_segment_ids: np.ndarray) -> np.ndarray:
        # (batch_size, 1, seq_len, global_seq_len)
        side_attention_mask = jnp.equal(attention_mask[..., None], global_segment_ids[:, None, :])[:, None, ...]
        attention_side_bias = jax.lax.select(
            side_attention_mask > 0,
            jnp.full(side_attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(side_attention_mask.shape, -1e10).astype(self.dtype),
        )
        # (batch_size, seq_len, global_seq_len)
        side_relative_position = _make_side_relative_position_ids(attention_mask, self.global_block_size)
        side_relative_position_bucket = self._relative_position_bucket(
            side_relative_position,
            bidirectional=True,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # (batch_size, seq_len, global_seq_len, num_heads)
        side_bias = self.global_relative_attention_bias(side_relative_position_bucket)

        # (batch_size, num_heads, seq_len, global_seq_len)
        side_bias = jnp.transpose(side_bias, (0, 3, 1, 2))
        # (batch_size, num_heads, seq_len, global_seq_len)
        attention_side_bias = attention_side_bias + side_bias
        return attention_side_bias

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.n_heads, self.key_value_proj_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[0], -1, self.inner_dim)

    def _create_position_bias(self, block_len: int, attention_mask: Optional[np.ndarray]) -> np.ndarray:
        # position_bias shape: (1, 1, n_heads, block_len, 3 * block_len)
        if self.has_relative_attention_bias:
            position_bias = self.compute_bias(block_len)
        elif attention_mask is not None:
            position_bias = jnp.zeros_like(attention_mask)
        else:
            position_bias = jnp.zeros((1, 1, self.n_heads, block_len, 3 * block_len), dtype=self.dtype)

        return position_bias

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        key_value_states=None,
        position_bias=None,
        output_attentions=False,
        deterministic=True,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        batch_size, seq_length = hidden_states.shape[:2]

        # Prepare components for transient-global attention
        # Obtain block_ids and global_segment_ids
        # global_seq_len := seq_len // self.global_block_size
        # shapes: (batch_size, seq_len) & (batch_size, global_seq_len)
        block_ids, global_segment_ids = _make_global_fixed_block_ids(
            attention_mask if attention_mask is not None else jnp.ones((batch_size, seq_length)),
            self.global_block_size,
        )
        # Create global inputs
        _global_seq_len = global_segment_ids.shape[-1]
        global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
        global_inputs = self.global_input_layer_norm(global_inputs)

        # q, k, v projections
        query_states = self.q(hidden_states)  # (batch_size, seq_length, n_heads * dim_per_head)
        key_states = self.k(hidden_states) if key_value_states is None else self.k(key_value_states)
        value_states = self.v(hidden_states) if key_value_states is None else self.v(key_value_states)

        # reshape to (batch_size, seq_length, n_heads, head_dim)
        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # Get global/side key/value_states
        side_key_states = self.k(global_inputs)
        side_value_states = self.v(global_inputs)

        # reshape to (batch_size, global_seq_len, n_heads, head_dim)
        side_key_states = self._split_heads(side_key_states)
        side_value_states = self._split_heads(side_value_states)

        # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, head_dim)
        query_states = _split_into_blocks(query_states, self.block_len, axis=1)
        key_states = _split_into_blocks(key_states, self.block_len, axis=1)
        value_states = _split_into_blocks(value_states, self.block_len, axis=1)

        # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
        key_states = _concatenate_3_blocks(key_states, block_axis=1, sequence_axis=2)
        value_states = _concatenate_3_blocks(value_states, block_axis=1, sequence_axis=2)

        # Tile side inputs across local key/value blocks
        # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head)
        reps = [1] * (side_key_states.ndim + 1)
        reps[1] = key_states.shape[1]
        side_key_states = jnp.tile(side_key_states[:, None, ...], reps)
        side_value_states = jnp.tile(side_value_states[:, None, ...], reps)

        # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones
        # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head)
        key_states = jnp.concatenate((key_states, side_key_states), axis=2)
        value_states = jnp.concatenate((value_states, side_value_states), axis=2)

        # counter-act scaling in dot_product_attention_weights function
        query_states *= jnp.sqrt(query_states.shape[-1])

        if attention_mask is not None:
            local_attention_mask = _get_local_attention_mask(attention_mask, self.block_len)
            local_attention_mask = jax.lax.select(
                local_attention_mask > 0,
                jnp.full(local_attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(local_attention_mask.shape, -1e10).astype(self.dtype),
            )
        else:
            local_attention_mask = None

        if position_bias is None:
            # compute position bias (only for first layer)
            position_bias = self._create_position_bias(self.block_len, attention_mask)

            if local_attention_mask is not None:
                position_bias = position_bias + local_attention_mask.swapaxes(1, 2)

            # Calculate global/side bias - shape: (batch_size, num_heads, seq_len, global_seq_len)
            if attention_mask is None:
                attention_mask = jnp.ones((batch_size, seq_length))
            side_position_bias = self.compute_side_bias(attention_mask, global_segment_ids)
            side_position_bias = _split_into_blocks(side_position_bias, self.block_len, axis=-2)
            side_position_bias = jnp.swapaxes(side_position_bias, 1, 2)
            position_bias = jnp.concatenate((position_bias, side_position_bias), axis=-1)

        # create dropout rng
        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # Softmax(QK^T)
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=position_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
        )

        # multiply with value states
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)

        # bring back to (batch_size, seq_length, d_model)
        attn_output = self._merge_heads(attn_output)
        attn_output = attn_output[:, :seq_length, :]

        # apply output matrix
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)

        return outputs
  921. class FlaxLongT5LayerLocalSelfAttention(nn.Module):
  922. """Local self attention used in encoder"""
  923. config: LongT5Config
  924. has_relative_attention_bias: bool = False
  925. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  926. def setup(self):
  927. self.LocalSelfAttention = FlaxLongT5LocalAttention(
  928. self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
  929. )
  930. self.layer_norm = FlaxLongT5LayerNorm(
  931. self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
  932. )
  933. self.dropout = nn.Dropout(self.config.dropout_rate)
  934. def __call__(
  935. self,
  936. hidden_states,
  937. attention_mask=None,
  938. position_bias=None,
  939. output_attentions=False,
  940. deterministic=True,
  941. **kwargs: Any, # to accept init_cache kwargs
  942. ):
  943. normed_hidden_states = self.layer_norm(hidden_states)
  944. attention_output = self.LocalSelfAttention(
  945. normed_hidden_states,
  946. attention_mask=attention_mask,
  947. position_bias=position_bias,
  948. output_attentions=output_attentions,
  949. deterministic=deterministic,
  950. )
  951. hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
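# T5-style pre-norm residual: attention runs on the layer-normed input, and its (dropped-out) output
# is added back onto the un-normalized residual stream.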
  952. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  953. return outputs
  954. class FlaxLongT5LayerTransientGlobalSelfAttention(nn.Module):
  955. """Transient-Global self attention used in encoder"""
  956. config: LongT5Config
  957. has_relative_attention_bias: bool = False
  958. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  959. def setup(self):
  960. self.TransientGlobalSelfAttention = FlaxLongT5TransientGlobalAttention(
  961. self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
  962. )
  963. self.layer_norm = FlaxLongT5LayerNorm(
  964. self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
  965. )
  966. self.dropout = nn.Dropout(self.config.dropout_rate)
  967. def __call__(
  968. self,
  969. hidden_states,
  970. attention_mask=None,
  971. position_bias=None,
  972. output_attentions=False,
  973. deterministic=True,
  974. **kwargs: Any, # to accept init_cache kwargs
  975. ):
  976. normed_hidden_states = self.layer_norm(hidden_states)
  977. attention_output = self.TransientGlobalSelfAttention(
  978. normed_hidden_states,
  979. attention_mask=attention_mask,
  980. position_bias=position_bias,
  981. output_attentions=output_attentions,
  982. deterministic=deterministic,
  983. )
  984. hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
  985. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  986. return outputs
  987. # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerSelfAttention with T5->LongT5
  988. class FlaxLongT5LayerSelfAttention(nn.Module):
  989. config: LongT5Config
  990. has_relative_attention_bias: bool = False
  991. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  992. def setup(self):
  993. self.SelfAttention = FlaxLongT5Attention(
  994. self.config,
  995. has_relative_attention_bias=self.has_relative_attention_bias,
  996. causal=self.config.causal,
  997. dtype=self.dtype,
  998. )
  999. self.layer_norm = FlaxLongT5LayerNorm(
  1000. self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
  1001. )
  1002. self.dropout = nn.Dropout(self.config.dropout_rate)
  1003. def __call__(
  1004. self,
  1005. hidden_states,
  1006. attention_mask=None,
  1007. position_bias=None,
  1008. output_attentions=False,
  1009. deterministic=True,
  1010. init_cache=False,
  1011. ):
  1012. normed_hidden_states = self.layer_norm(hidden_states)
  1013. attention_output = self.SelfAttention(
  1014. normed_hidden_states,
  1015. attention_mask=attention_mask,
  1016. position_bias=position_bias,
  1017. output_attentions=output_attentions,
  1018. deterministic=deterministic,
  1019. init_cache=init_cache,
  1020. )
  1021. hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
  1022. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  1023. return outputs
  1024. # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCrossAttention with T5->LongT5
  1025. class FlaxLongT5LayerCrossAttention(nn.Module):
  1026. config: LongT5Config
  1027. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  1028. def setup(self):
  1029. self.EncDecAttention = FlaxLongT5Attention(
  1030. self.config, has_relative_attention_bias=False, causal=False, dtype=self.dtype
  1031. )
  1032. self.layer_norm = FlaxLongT5LayerNorm(
  1033. self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
  1034. )
  1035. self.dropout = nn.Dropout(self.config.dropout_rate)
  1036. def __call__(
  1037. self,
  1038. hidden_states,
  1039. key_value_states,
  1040. attention_mask=None,
  1041. position_bias=None,
  1042. output_attentions=False,
  1043. deterministic=True,
  1044. ):
  1045. normed_hidden_states = self.layer_norm(hidden_states)
  1046. attention_output = self.EncDecAttention(
  1047. normed_hidden_states,
  1048. attention_mask=attention_mask,
  1049. key_value_states=key_value_states,
  1050. position_bias=position_bias,
  1051. output_attentions=output_attentions,
  1052. )
  1053. hidden_states = hidden_states + self.dropout(attention_output[0], deterministic=deterministic)
  1054. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  1055. return outputs
  1056. class FlaxLongT5Block(nn.Module):
  1057. config: LongT5Config
  1058. has_relative_attention_bias: bool = False
  1059. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  1060. def setup(self):
  1061. self.causal = self.config.causal
  1062. if self.causal:
  1063. attention_layer = FlaxLongT5LayerSelfAttention
  1064. elif self.config.encoder_attention_type == "local":
  1065. attention_layer = FlaxLongT5LayerLocalSelfAttention
  1066. elif self.config.encoder_attention_type == "transient-global":
  1067. attention_layer = FlaxLongT5LayerTransientGlobalSelfAttention
  1068. else:
  1069. raise ValueError(
  1070. "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, "
  1071. f"but got {self.config.encoder_attention_type}."
  1072. )
  1073. self.layer = (
  1074. attention_layer(
  1075. self.config,
  1076. has_relative_attention_bias=self.has_relative_attention_bias,
  1077. name=str(0),
  1078. dtype=self.dtype,
  1079. ),
  1080. )
  1081. feed_forward_index = 1
  1082. if self.causal:
  1083. self.layer += (FlaxLongT5LayerCrossAttention(self.config, name=str(1), dtype=self.dtype),)
  1084. feed_forward_index += 1
  1085. self.layer += (FlaxLongT5LayerFF(self.config, name=str(feed_forward_index), dtype=self.dtype),)
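# Resulting layout of self.layer: index 0 is the self-attention variant selected above (causal in the
# decoder, local or transient-global in the encoder), index 1 is cross-attention (decoder only), and
# the last entry is the feed-forward layer.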
  1086. # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Block.__call__ with T5->LongT5
  1087. def __call__(
  1088. self,
  1089. hidden_states,
  1090. attention_mask=None,
  1091. position_bias=None,
  1092. encoder_hidden_states=None,
  1093. encoder_attention_mask=None,
  1094. encoder_decoder_position_bias=None,
  1095. output_attentions=False,
  1096. return_dict=True,
  1097. deterministic=True,
  1098. init_cache=False,
  1099. ):
  1100. self_attention_outputs = self.layer[0](
  1101. hidden_states,
  1102. attention_mask=attention_mask,
  1103. position_bias=position_bias,
  1104. output_attentions=output_attentions,
  1105. deterministic=deterministic,
  1106. init_cache=init_cache,
  1107. )
  1108. hidden_states = self_attention_outputs[0]
  1109. attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
  1110. do_cross_attention = self.causal and encoder_hidden_states is not None
  1111. if do_cross_attention:
  1112. cross_attention_outputs = self.layer[1](
  1113. hidden_states,
  1114. key_value_states=encoder_hidden_states,
  1115. attention_mask=encoder_attention_mask,
  1116. position_bias=encoder_decoder_position_bias,
  1117. output_attentions=output_attentions,
  1118. deterministic=deterministic,
  1119. )
  1120. hidden_states = cross_attention_outputs[0]
  1121. # Keep cross-attention outputs and relative position weights
  1122. attention_outputs = attention_outputs + cross_attention_outputs[1:]
  1123. # Apply Feed Forward layer
  1124. hidden_states = self.layer[-1](hidden_states, deterministic=deterministic)
  1125. outputs = (hidden_states,)
  1126. outputs = outputs + attention_outputs
1127. # returns hidden-states, (self-attention position bias), (self-attention weights),
  1128. # (cross-attention position bias), (cross-attention weights)
  1129. return outputs
  1130. # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5LayerCollection with T5->LongT5
  1131. class FlaxLongT5LayerCollection(nn.Module):
  1132. config: LongT5Config
  1133. has_relative_attention_bias: bool
  1134. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  1135. def setup(self):
  1136. self.layer = FlaxLongT5Block(
  1137. self.config, has_relative_attention_bias=self.has_relative_attention_bias, dtype=self.dtype
  1138. )
  1139. def __call__(
  1140. self,
  1141. hidden_states,
  1142. attention_mask=None,
  1143. position_bias=None,
  1144. encoder_hidden_states=None,
  1145. encoder_attention_mask=None,
  1146. encoder_decoder_position_bias=None,
  1147. output_attentions=False,
  1148. deterministic=True,
  1149. init_cache=False,
  1150. ):
  1151. return self.layer(
  1152. hidden_states,
  1153. attention_mask=attention_mask,
  1154. position_bias=position_bias,
  1155. encoder_hidden_states=encoder_hidden_states,
  1156. encoder_attention_mask=encoder_attention_mask,
  1157. encoder_decoder_position_bias=encoder_decoder_position_bias,
  1158. output_attentions=output_attentions,
  1159. deterministic=deterministic,
  1160. init_cache=init_cache,
  1161. )
  1162. # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5BlockCollection with T5->LongT5
  1163. class FlaxLongT5BlockCollection(nn.Module):
  1164. config: LongT5Config
  1165. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  1166. gradient_checkpointing: bool = False
  1167. def setup(self):
  1168. self.causal = self.config.causal
  1169. if self.gradient_checkpointing:
  1170. FlaxLongT5CheckpointLayer = remat(FlaxLongT5LayerCollection, static_argnums=(6, 7, 8))
  1171. self.blocks = [
  1172. FlaxLongT5CheckpointLayer(
  1173. self.config,
  1174. has_relative_attention_bias=(i == 0),
  1175. dtype=self.dtype,
  1176. name=str(i),
  1177. )
  1178. for i in range(self.config.num_layers)
  1179. ]
  1180. else:
  1181. self.blocks = [
  1182. FlaxLongT5LayerCollection(
  1183. self.config,
  1184. has_relative_attention_bias=(i == 0),
  1185. dtype=self.dtype,
  1186. name=str(i),
  1187. )
  1188. for i in range(self.config.num_layers)
  1189. ]
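# Only the first block (i == 0) owns the relative attention bias; the remaining blocks reuse the
# position_bias returned by that block (see the loop in __call__ below).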
  1190. def __call__(
  1191. self,
  1192. hidden_states=None,
  1193. attention_mask=None,
  1194. encoder_hidden_states=None,
  1195. encoder_attention_mask=None,
  1196. output_attentions: bool = False,
  1197. output_hidden_states: bool = False,
  1198. deterministic: bool = True,
  1199. init_cache: bool = False,
  1200. ):
1201. # Prepare containers for the optional outputs
  1202. all_hidden_states = () if output_hidden_states else None
  1203. all_attentions = () if output_attentions else None
  1204. all_cross_attentions = () if (output_attentions and self.causal) else None
  1205. position_bias = None
  1206. encoder_decoder_position_bias = None
  1207. for i, layer_module in enumerate(self.blocks):
  1208. if output_hidden_states:
  1209. all_hidden_states = all_hidden_states + (hidden_states,)
  1210. layer_outputs = layer_module(
  1211. hidden_states,
  1212. attention_mask,
  1213. position_bias,
  1214. encoder_hidden_states,
  1215. encoder_attention_mask,
  1216. encoder_decoder_position_bias,
  1217. output_attentions,
  1218. deterministic,
  1219. init_cache,
  1220. )
  1221. hidden_states = layer_outputs[0]
1222. # We share the position biases between the layers - the first layer stores them
1223. # layer_outputs = hidden-states, (self-attention position bias), (self-attention weights),
  1224. # (cross-attention position bias), (cross-attention weights)
  1225. position_bias = layer_outputs[1]
  1226. if self.causal and encoder_hidden_states is not None:
  1227. encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
  1228. if output_attentions:
  1229. all_attentions = all_attentions + (layer_outputs[2],)
  1230. if self.causal:
  1231. all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
  1232. return FlaxBaseModelOutputWithPastAndCrossAttentions(
  1233. last_hidden_state=hidden_states,
  1234. hidden_states=all_hidden_states,
  1235. attentions=all_attentions,
  1236. cross_attentions=all_cross_attentions,
  1237. )
  1238. # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Stack with T5->LongT5
  1239. class FlaxLongT5Stack(nn.Module):
  1240. config: LongT5Config
  1241. embed_tokens: nn.Embed
  1242. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  1243. gradient_checkpointing: bool = False
  1244. def setup(self):
  1245. self.causal = self.config.causal
  1246. self.block = FlaxLongT5BlockCollection(
  1247. self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
  1248. )
  1249. self.final_layer_norm = FlaxLongT5LayerNorm(
  1250. self.config.d_model, eps=self.config.layer_norm_epsilon, dtype=self.dtype
  1251. )
  1252. self.dropout = nn.Dropout(self.config.dropout_rate)
  1253. def __call__(
  1254. self,
  1255. input_ids=None,
  1256. attention_mask=None,
  1257. encoder_hidden_states=None,
  1258. encoder_attention_mask=None,
  1259. output_attentions: bool = False,
  1260. output_hidden_states: bool = False,
  1261. return_dict: bool = True,
  1262. deterministic: bool = True,
  1263. init_cache: bool = False,
  1264. ):
  1265. hidden_states = self.embed_tokens(input_ids)
  1266. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  1267. outputs = self.block(
  1268. hidden_states,
  1269. attention_mask=attention_mask,
  1270. encoder_hidden_states=encoder_hidden_states,
  1271. encoder_attention_mask=encoder_attention_mask,
  1272. output_attentions=output_attentions,
  1273. output_hidden_states=output_hidden_states,
  1274. deterministic=deterministic,
  1275. init_cache=init_cache,
  1276. )
  1277. hidden_states = outputs[0]
  1278. hidden_states = self.final_layer_norm(hidden_states)
  1279. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  1280. # Add last layer
  1281. all_hidden_states = None
  1282. if output_hidden_states:
  1283. all_hidden_states = outputs.hidden_states
  1284. all_hidden_states = all_hidden_states + (hidden_states,)
  1285. if not return_dict:
  1286. if output_hidden_states:
  1287. return (
  1288. hidden_states,
  1289. all_hidden_states,
  1290. ) + outputs[2:]
  1291. return (hidden_states,) + outputs[1:]
  1292. return FlaxBaseModelOutputWithPastAndCrossAttentions(
  1293. last_hidden_state=hidden_states,
  1294. hidden_states=all_hidden_states,
  1295. attentions=outputs.attentions,
  1296. cross_attentions=outputs.cross_attentions,
  1297. )
  1298. LONGT5_ENCODE_INPUTS_DOCSTRING = r"""
  1299. Args:
  1300. input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
  1301. Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
  1302. you should be able to pad the inputs on both the right and the left.
  1303. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1304. [`PreTrainedTokenizer.__call__`] for details.
1305. To know more about how to prepare `input_ids` for pretraining, take a look at [LONGT5
  1306. Training](./longt5#training).
  1307. attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
  1308. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1309. - 1 for tokens that are **not masked**,
  1310. - 0 for tokens that are **masked**.
  1311. [What are attention masks?](../glossary#attention-mask)
  1312. output_attentions (`bool`, *optional*):
  1313. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  1314. tensors for more detail.
  1315. output_hidden_states (`bool`, *optional*):
  1316. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  1317. more detail.
  1318. return_dict (`bool`, *optional*):
  1319. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1320. """
  1321. LONGT5_DECODE_INPUTS_DOCSTRING = r"""
  1322. Args:
  1323. decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
  1324. Indices of decoder input sequence tokens in the vocabulary.
  1325. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1326. [`PreTrainedTokenizer.__call__`] for details.
  1327. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1328. For training, `decoder_input_ids` should be provided.
1329. encoder_outputs (`tuple(tuple(jnp.ndarray))`):
  1330. Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
1331. `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
  1332. hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
  1333. encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
  1334. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1335. - 1 for tokens that are **not masked**,
  1336. - 0 for tokens that are **masked**.
  1337. [What are attention masks?](../glossary#attention-mask)
  1338. decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
  1339. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1340. be used by default.
1341. If you want to change padding behavior, you should modify it to your needs. See diagram 1 in [the
  1342. paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
  1343. past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
  1344. Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
  1345. auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
  1346. output_attentions (`bool`, *optional*):
  1347. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  1348. tensors for more detail.
  1349. output_hidden_states (`bool`, *optional*):
  1350. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  1351. more detail.
  1352. return_dict (`bool`, *optional*):
  1353. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1354. """
  1355. LONGT5_INPUTS_DOCSTRING = r"""
  1356. Args:
  1357. input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
  1358. Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
  1359. you should be able to pad the inputs on both the right and the left.
  1360. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1361. [`PreTrainedTokenizer.__call__`] for details.
  1362. [What are input IDs?](../glossary#input-ids)
1363. To know more about how to prepare `input_ids` for pretraining, take a look at [LONGT5
  1364. Training](./longt5#training).
  1365. attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
  1366. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1367. - 1 for tokens that are **not masked**,
  1368. - 0 for tokens that are **masked**.
  1369. [What are attention masks?](../glossary#attention-mask)
  1370. decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
  1371. Indices of decoder input sequence tokens in the vocabulary.
  1372. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1373. [`PreTrainedTokenizer.__call__`] for details.
  1374. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1375. LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  1376. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1377. `past_key_values`).
1378. To know more about how to prepare `decoder_input_ids` for pretraining, take a look at [LONGT5
  1379. Training](./longt5#training).
  1380. decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
  1381. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1382. be used by default.
1383. encoder_outputs (`tuple(tuple(jnp.ndarray))`, *optional*):
1384. Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
  1385. `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
  1386. the output of the last layer of the encoder. Used in the cross-attention of the decoder.
  1387. past_key_values (`tuple(tuple(jnp.ndarray))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
  1388. Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
  1389. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
  1390. don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
  1391. `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  1392. output_attentions (`bool`, *optional*):
  1393. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  1394. tensors for more detail.
  1395. output_hidden_states (`bool`, *optional*):
  1396. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  1397. more detail.
  1398. return_dict (`bool`, *optional*):
  1399. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1400. """
  1401. class FlaxLongT5PreTrainedModel(FlaxPreTrainedModel):
  1402. """
  1403. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  1404. models.
  1405. """
  1406. config_class = LongT5Config
  1407. base_model_prefix = "transformer"
  1408. module_class: nn.Module = None
  1409. def __init__(
  1410. self,
  1411. config: LongT5Config,
  1412. input_shape: Tuple[int] = (1, 1),
  1413. seed: int = 0,
  1414. dtype: jnp.dtype = jnp.float32,
  1415. _do_init: bool = True,
  1416. **kwargs,
  1417. ):
  1418. module = self.module_class(config=config, dtype=dtype, **kwargs)
  1419. super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
  1420. def enable_gradient_checkpointing(self):
  1421. self._module = self.module_class(
  1422. config=self.config,
  1423. dtype=self.dtype,
  1424. gradient_checkpointing=True,
  1425. )
  1426. def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
  1427. # init input tensors
  1428. input_ids = jnp.zeros(input_shape, dtype="i4")
  1429. attention_mask = jnp.ones_like(input_ids)
  1430. decoder_input_ids = jnp.ones_like(input_ids)
  1431. decoder_attention_mask = jnp.ones_like(input_ids)
  1432. params_rng, dropout_rng = jax.random.split(rng)
  1433. rngs = {"params": params_rng, "dropout": dropout_rng}
  1434. random_params = self.module.init(
  1435. rngs,
  1436. input_ids,
  1437. attention_mask,
  1438. decoder_input_ids,
  1439. decoder_attention_mask,
  1440. )["params"]
  1441. if params is not None:
  1442. random_params = flatten_dict(unfreeze(random_params))
  1443. params = flatten_dict(unfreeze(params))
  1444. for missing_key in self._missing_keys:
  1445. params[missing_key] = random_params[missing_key]
  1446. self._missing_keys = set()
  1447. return freeze(unflatten_dict(params))
  1448. else:
  1449. return random_params
  1450. @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
  1451. def __call__(
  1452. self,
  1453. input_ids: jnp.ndarray,
  1454. attention_mask: Optional[jnp.ndarray] = None,
  1455. decoder_input_ids: jnp.ndarray = None,
  1456. decoder_attention_mask: Optional[jnp.ndarray] = None,
  1457. output_attentions: Optional[bool] = None,
  1458. output_hidden_states: Optional[bool] = None,
  1459. return_dict: Optional[bool] = None,
  1460. train: bool = False,
  1461. params: dict = None,
  1462. dropout_rng: PRNGKey = None,
  1463. ):
  1464. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1465. output_hidden_states = (
  1466. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1467. )
  1468. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1469. if decoder_input_ids is None:
  1470. raise ValueError(
  1471. "Make sure to provide both `input_ids` and `decoder_input_ids`. `decoder_input_ids` is not passed"
  1472. " here."
  1473. )
  1474. # prepare encoder inputs
  1475. if attention_mask is None:
  1476. attention_mask = jnp.ones_like(input_ids)
  1477. # prepare decoder inputs
  1478. if decoder_attention_mask is None:
  1479. decoder_attention_mask = jnp.ones_like(decoder_input_ids)
  1480. # Handle any PRNG if needed
  1481. rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
  1482. return self.module.apply(
  1483. {"params": params or self.params},
  1484. input_ids=jnp.array(input_ids, dtype="i4"),
  1485. attention_mask=jnp.array(attention_mask, dtype="i4"),
  1486. decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
  1487. decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
  1488. output_attentions=output_attentions,
  1489. output_hidden_states=output_hidden_states,
  1490. return_dict=return_dict,
  1491. deterministic=not train,
  1492. rngs=rngs,
  1493. )
  1494. def init_cache(self, batch_size, max_length, encoder_outputs):
  1495. r"""
  1496. Args:
  1497. batch_size (`int`):
  1498. batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
  1499. max_length (`int`):
  1500. maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
  1501. cache.
1502. encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
  1503. `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
1504. `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`
  1505. is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
  1506. cross-attention of the decoder.
  1507. """
  1508. # init input variables to retrieve cache
  1509. decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
  1510. decoder_attention_mask = jnp.ones_like(decoder_input_ids)
  1511. def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
  1512. decoder_module = module._get_decoder_module()
  1513. return decoder_module(
  1514. decoder_input_ids,
  1515. decoder_attention_mask,
  1516. **kwargs,
  1517. )
  1518. init_variables = self.module.init(
  1519. jax.random.PRNGKey(0),
  1520. decoder_input_ids=decoder_input_ids,
  1521. decoder_attention_mask=decoder_attention_mask,
  1522. encoder_hidden_states=encoder_outputs[0],
  1523. init_cache=True,
  1524. method=_decoder_forward, # we only need to call the decoder to init the cache
  1525. )
  1526. return unfreeze(init_variables["cache"])
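# Illustrative usage sketch (not executed here): for fast auto-regressive decoding one would typically do
#   encoder_outputs = model.encode(input_ids)
#   past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)
#   outputs = model.decode(decoder_input_ids, encoder_outputs, past_key_values=past_key_values)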
  1527. @add_start_docstrings(LONGT5_ENCODE_INPUTS_DOCSTRING)
  1528. @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=LongT5Config)
  1529. def encode(
  1530. self,
  1531. input_ids: jnp.ndarray,
  1532. attention_mask: Optional[jnp.ndarray] = None,
  1533. output_attentions: Optional[bool] = None,
  1534. output_hidden_states: Optional[bool] = None,
  1535. return_dict: Optional[bool] = None,
  1536. train: bool = False,
  1537. params: dict = None,
  1538. dropout_rng: PRNGKey = None,
  1539. ):
  1540. r"""
  1541. Returns:
  1542. Example:
  1543. ```python
  1544. >>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration
  1545. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
  1546. >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
  1547. >>> text = "My friends are cool but they eat too many carbs."
  1548. >>> inputs = tokenizer(text, return_tensors="np")
  1549. >>> encoder_outputs = model.encode(**inputs)
  1550. ```"""
  1551. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1552. output_hidden_states = (
  1553. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1554. )
  1555. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1556. if attention_mask is None:
  1557. attention_mask = jnp.ones_like(input_ids)
  1558. # Handle any PRNG if needed
  1559. rngs = {}
  1560. if dropout_rng is not None:
  1561. rngs["dropout"] = dropout_rng
  1562. def _encoder_forward(module, input_ids, attention_mask, **kwargs):
  1563. encode_module = module._get_encoder_module()
  1564. return encode_module(input_ids, attention_mask, **kwargs)
  1565. return self.module.apply(
  1566. {"params": params or self.params},
  1567. input_ids=jnp.array(input_ids, dtype="i4"),
  1568. attention_mask=jnp.array(attention_mask, dtype="i4"),
  1569. output_attentions=output_attentions,
  1570. output_hidden_states=output_hidden_states,
  1571. return_dict=return_dict,
  1572. deterministic=not train,
  1573. rngs=rngs,
  1574. method=_encoder_forward,
  1575. )
  1576. @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING)
  1577. @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=LongT5Config)
  1578. def decode(
  1579. self,
  1580. decoder_input_ids,
  1581. encoder_outputs,
  1582. encoder_attention_mask: Optional[jnp.ndarray] = None,
  1583. decoder_attention_mask: Optional[jnp.ndarray] = None,
  1584. past_key_values: dict = None,
  1585. output_attentions: Optional[bool] = None,
  1586. output_hidden_states: Optional[bool] = None,
  1587. return_dict: Optional[bool] = None,
  1588. train: bool = False,
  1589. params: dict = None,
  1590. dropout_rng: PRNGKey = None,
  1591. ):
  1592. r"""
  1593. Returns:
  1594. Example:
  1595. ```python
  1596. >>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration
  1597. >>> import jax.numpy as jnp
  1598. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
  1599. >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
  1600. >>> text = "My friends are cool but they eat too many carbs."
  1601. >>> inputs = tokenizer(text, return_tensors="np")
  1602. >>> encoder_outputs = model.encode(**inputs)
  1603. >>> decoder_start_token_id = model.config.decoder_start_token_id
  1604. >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
  1605. >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
  1606. >>> logits = outputs.logits
  1607. ```"""
  1608. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1609. output_hidden_states = (
  1610. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1611. )
  1612. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1613. encoder_hidden_states = encoder_outputs[0]
  1614. if encoder_attention_mask is None:
  1615. batch_size, sequence_length = encoder_hidden_states.shape[:2]
  1616. encoder_attention_mask = jnp.ones((batch_size, sequence_length))
  1617. batch_size, sequence_length = decoder_input_ids.shape
  1618. if decoder_attention_mask is None:
  1619. decoder_attention_mask = jnp.ones((batch_size, sequence_length))
  1620. # Handle any PRNG if needed
  1621. rngs = {}
  1622. if dropout_rng is not None:
  1623. rngs["dropout"] = dropout_rng
  1624. inputs = {"params": params or self.params}
1625. # If past_key_values are passed, then the cache is already initialized and the private flag init_cache has to be
1626. # passed down to ensure the cache is used. The cache also has to be marked as mutable so that
1627. # it can be changed by the FlaxLongT5Attention module.
  1628. if past_key_values:
  1629. inputs["cache"] = past_key_values
  1630. mutable = ["cache"]
  1631. else:
  1632. mutable = False
  1633. def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
  1634. decoder_module = module._get_decoder_module()
  1635. return decoder_module(
  1636. decoder_input_ids,
  1637. decoder_attention_mask,
  1638. **kwargs,
  1639. )
  1640. outputs = self.module.apply(
  1641. inputs,
  1642. decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
  1643. decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
  1644. encoder_hidden_states=encoder_hidden_states,
  1645. encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
  1646. output_attentions=output_attentions,
  1647. output_hidden_states=output_hidden_states,
  1648. return_dict=return_dict,
  1649. deterministic=not train,
  1650. rngs=rngs,
  1651. mutable=mutable,
  1652. method=_decoder_forward,
  1653. )
  1654. # add updated cache to model output
  1655. if past_key_values is not None and return_dict:
  1656. outputs, past = outputs
  1657. outputs["past_key_values"] = unfreeze(past["cache"])
  1658. return outputs
  1659. elif past_key_values is not None and not return_dict:
  1660. outputs, past = outputs
  1661. outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
  1662. return outputs
  1663. LONGT5_START_DOCSTRING = r"""
  1664. The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long
  1665. Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo
  1666. Ni, Yun-Hsuan Sung and Yinfei Yang. It's an encoder-decoder transformer pre-trained in a text-to-text denoising
  1667. generative setting. LongT5 model is an extension of T5 model, and it enables using one of the two different
  1668. efficient attention mechanisms - (1) Local attention, or (2) Transient-Global attention.
  1669. This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
  1670. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  1671. etc.)
  1672. This model is also a Flax Linen
  1673. [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
  1674. regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
  1675. Finally, this model supports inherent JAX features such as:
  1676. - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
  1677. - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
  1678. - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
  1679. - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
  1680. Parameters:
  1681. config ([`LongT5Config`]): Model configuration class with all the parameters of the model.
  1682. Initializing with a config file does not load the weights associated with the model, only the
  1683. configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
  1684. dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
  1685. The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
  1686. `jax.numpy.bfloat16` (on TPUs).
  1687. This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
  1688. specified all the computation will be performed with the given `dtype`.
  1689. **Note that this only specifies the dtype of the computation and does not influence the dtype of model
  1690. parameters.**
  1691. If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
  1692. [`~FlaxPreTrainedModel.to_bf16`].
  1693. """
  1694. @add_start_docstrings(
  1695. "The bare LONGT5 Model transformer outputting raw hidden-stateswithout any specific head on top.",
  1696. LONGT5_START_DOCSTRING,
  1697. )
  1698. # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Module with T5->LongT5
  1699. class FlaxLongT5Module(nn.Module):
  1700. config: LongT5Config
  1701. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  1702. gradient_checkpointing: bool = False
  1703. def _get_encoder_module(self):
  1704. return self.encoder
  1705. def _get_decoder_module(self):
  1706. return self.decoder
  1707. def setup(self):
  1708. self.shared = nn.Embed(
  1709. self.config.vocab_size,
  1710. self.config.d_model,
  1711. embedding_init=jax.nn.initializers.normal(self.config.initializer_factor * 1.0),
  1712. dtype=self.dtype,
  1713. )
  1714. encoder_config = copy.deepcopy(self.config)
  1715. encoder_config.causal = False
  1716. self.encoder = FlaxLongT5Stack(
  1717. encoder_config,
  1718. embed_tokens=self.shared,
  1719. dtype=self.dtype,
  1720. gradient_checkpointing=self.gradient_checkpointing,
  1721. )
  1722. decoder_config = copy.deepcopy(self.config)
  1723. decoder_config.causal = True
  1724. decoder_config.num_layers = self.config.num_decoder_layers
  1725. self.decoder = FlaxLongT5Stack(
  1726. decoder_config,
  1727. embed_tokens=self.shared,
  1728. dtype=self.dtype,
  1729. gradient_checkpointing=self.gradient_checkpointing,
  1730. )
  1731. def __call__(
  1732. self,
  1733. input_ids=None,
  1734. attention_mask=None,
  1735. decoder_input_ids=None,
  1736. decoder_attention_mask=None,
  1737. encoder_outputs=None,
  1738. output_attentions=None,
  1739. output_hidden_states=None,
  1740. return_dict=None,
  1741. deterministic: bool = True,
  1742. ):
  1743. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1744. # Encode if needed (training, first prediction pass)
  1745. encoder_outputs = self.encoder(
  1746. input_ids=input_ids,
  1747. attention_mask=attention_mask,
  1748. output_attentions=output_attentions,
  1749. output_hidden_states=output_hidden_states,
  1750. return_dict=return_dict,
  1751. deterministic=deterministic,
  1752. )
  1753. # Decode
  1754. decoder_outputs = self.decoder(
  1755. input_ids=decoder_input_ids,
  1756. attention_mask=decoder_attention_mask,
  1757. encoder_hidden_states=encoder_outputs[0],
  1758. encoder_attention_mask=attention_mask,
  1759. output_attentions=output_attentions,
  1760. output_hidden_states=output_hidden_states,
  1761. return_dict=return_dict,
  1762. deterministic=deterministic,
  1763. )
  1764. if not return_dict:
  1765. return decoder_outputs + encoder_outputs
  1766. return FlaxSeq2SeqModelOutput(
  1767. last_hidden_state=decoder_outputs.last_hidden_state,
  1768. past_key_values=decoder_outputs.past_key_values,
  1769. decoder_hidden_states=decoder_outputs.hidden_states,
  1770. decoder_attentions=decoder_outputs.attentions,
  1771. cross_attentions=decoder_outputs.cross_attentions,
  1772. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1773. encoder_hidden_states=encoder_outputs.hidden_states,
  1774. encoder_attentions=encoder_outputs.attentions,
  1775. )
  1776. # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5Model with T5->LongT5
  1777. class FlaxLongT5Model(FlaxLongT5PreTrainedModel):
  1778. module_class = FlaxLongT5Module
  1779. append_call_sample_docstring(FlaxLongT5Model, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
  1780. FLAX_LONGT5_MODEL_DOCSTRING = """
  1781. Returns:
  1782. Example:
  1783. ```python
  1784. >>> from transformers import AutoTokenizer, FlaxLongT5Model
  1785. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
  1786. >>> model = FlaxLongT5Model.from_pretrained("google/long-t5-local-base")
  1787. >>> input_ids = tokenizer(
  1788. ... "Studies have been shown that owning a dog is good for you", return_tensors="np"
  1789. ... ).input_ids
  1790. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="np").input_ids
  1791. >>> # forward pass
  1792. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1793. >>> last_hidden_states = outputs.last_hidden_state
  1794. ```
  1795. """
  1796. overwrite_call_docstring(FlaxLongT5Model, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_MODEL_DOCSTRING)
  1797. append_replace_return_docstrings(FlaxLongT5Model, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
  1798. @add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING)
  1799. # Copied from transformers.models.t5.modeling_flax_t5.FlaxT5ForConditionalGenerationModule with T5->LongT5
  1800. class FlaxLongT5ForConditionalGenerationModule(nn.Module):
  1801. config: LongT5Config
  1802. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  1803. gradient_checkpointing: bool = False
  1804. def _get_encoder_module(self):
  1805. return self.encoder
  1806. def _get_decoder_module(self):
  1807. return self.decoder
  1808. def setup(self):
  1809. self.model_dim = self.config.d_model
  1810. self.shared = nn.Embed(
  1811. self.config.vocab_size,
  1812. self.config.d_model,
  1813. embedding_init=jax.nn.initializers.normal(self.config.initializer_factor),
  1814. dtype=self.dtype,
  1815. )
  1816. encoder_config = copy.deepcopy(self.config)
  1817. encoder_config.causal = False
  1818. encoder_config.use_cache = False
  1819. encoder_config.is_encoder_decoder = False
  1820. self.encoder = FlaxLongT5Stack(
  1821. encoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
  1822. )
  1823. decoder_config = copy.deepcopy(self.config)
  1824. decoder_config.causal = True
  1825. decoder_config.is_encoder_decoder = False
  1826. decoder_config.num_layers = self.config.num_decoder_layers
  1827. self.decoder = FlaxLongT5Stack(
  1828. decoder_config, self.shared, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
  1829. )
  1830. self.lm_head = nn.Dense(
  1831. self.config.vocab_size,
  1832. use_bias=False,
  1833. kernel_init=jax.nn.initializers.normal(self.config.initializer_factor),
  1834. dtype=self.dtype,
  1835. )
  1836. def __call__(
  1837. self,
  1838. input_ids=None,
  1839. attention_mask=None,
  1840. decoder_input_ids=None,
  1841. decoder_attention_mask=None,
  1842. encoder_outputs=None,
  1843. output_attentions=None,
  1844. output_hidden_states=None,
  1845. return_dict=None,
  1846. deterministic: bool = True,
  1847. ):
  1848. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1849. # Encode
  1850. encoder_outputs = self.encoder(
  1851. input_ids=input_ids,
  1852. attention_mask=attention_mask,
  1853. output_attentions=output_attentions,
  1854. output_hidden_states=output_hidden_states,
  1855. return_dict=return_dict,
  1856. deterministic=deterministic,
  1857. )
  1858. hidden_states = encoder_outputs[0]
  1859. # Decode
  1860. decoder_outputs = self.decoder(
  1861. input_ids=decoder_input_ids,
  1862. attention_mask=decoder_attention_mask,
  1863. encoder_hidden_states=hidden_states,
  1864. encoder_attention_mask=attention_mask,
  1865. output_attentions=output_attentions,
  1866. output_hidden_states=output_hidden_states,
  1867. return_dict=return_dict,
  1868. deterministic=deterministic,
  1869. )
  1870. sequence_output = decoder_outputs[0]
  1871. if self.config.tie_word_embeddings:
  1872. # Rescale output before projecting on vocab
  1873. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
  1874. sequence_output = sequence_output * (self.model_dim**-0.5)
  1875. if self.config.tie_word_embeddings:
  1876. shared_embedding = self.shared.variables["params"]["embedding"]
  1877. lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
  1878. else:
  1879. lm_logits = self.lm_head(sequence_output)
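# Note on the branch above: with tie_word_embeddings, the transposed shared token-embedding matrix is
# used as the lm_head kernel, and the d_model**-0.5 rescale follows the T5 convention for tied
# input/output embeddings.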
  1880. if not return_dict:
  1881. return (lm_logits,) + decoder_outputs[1:] + encoder_outputs
  1882. return FlaxSeq2SeqLMOutput(
  1883. logits=lm_logits,
  1884. past_key_values=decoder_outputs.past_key_values,
  1885. decoder_hidden_states=decoder_outputs.hidden_states,
  1886. decoder_attentions=decoder_outputs.attentions,
  1887. cross_attentions=decoder_outputs.cross_attentions,
  1888. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1889. encoder_hidden_states=encoder_outputs.hidden_states,
  1890. encoder_attentions=encoder_outputs.attentions,
  1891. )
  1892. class FlaxLongT5ForConditionalGeneration(FlaxLongT5PreTrainedModel):
  1893. module_class = FlaxLongT5ForConditionalGenerationModule
  1894. @add_start_docstrings(LONGT5_DECODE_INPUTS_DOCSTRING)
  1895. @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=LongT5Config)
  1896. def decode(
  1897. self,
  1898. decoder_input_ids,
  1899. encoder_outputs,
  1900. encoder_attention_mask: Optional[jnp.ndarray] = None,
  1901. decoder_attention_mask: Optional[jnp.ndarray] = None,
  1902. past_key_values: dict = None,
  1903. output_attentions: Optional[bool] = None,
  1904. output_hidden_states: Optional[bool] = None,
  1905. return_dict: Optional[bool] = None,
  1906. train: bool = False,
  1907. params: dict = None,
  1908. dropout_rng: PRNGKey = None,
  1909. ):
  1910. r"""
  1911. Returns:
  1912. Example:
  1913. ```python
  1914. >>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration
  1915. >>> import jax.numpy as jnp
  1916. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
  1917. >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
  1918. >>> text = "summarize: My friends are cool but they eat too many carbs."
  1919. >>> inputs = tokenizer(text, return_tensors="np")
  1920. >>> encoder_outputs = model.encode(**inputs)
  1921. >>> decoder_start_token_id = model.config.decoder_start_token_id
  1922. >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
  1923. >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
  1924. >>> logits = outputs.logits
  1925. ```"""
  1926. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1927. output_hidden_states = (
  1928. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1929. )
  1930. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1931. encoder_hidden_states = encoder_outputs[0]
  1932. if encoder_attention_mask is None:
  1933. batch_size, sequence_length = encoder_hidden_states.shape[:2]
  1934. encoder_attention_mask = jnp.ones((batch_size, sequence_length))
  1935. batch_size, sequence_length = decoder_input_ids.shape
  1936. if decoder_attention_mask is None:
  1937. decoder_attention_mask = jnp.ones((batch_size, sequence_length))
  1938. # Handle any PRNG if needed
  1939. rngs = {}
  1940. if dropout_rng is not None:
  1941. rngs["dropout"] = dropout_rng
  1942. inputs = {"params": params or self.params}
1943. # If past_key_values are passed, then the cache is already initialized and the private flag init_cache has to be
1944. # passed down to ensure the cache is used. The cache also has to be marked as mutable so that
1945. # it can be changed by the FlaxLongT5Attention module.
  1946. if past_key_values:
  1947. inputs["cache"] = past_key_values
  1948. mutable = ["cache"]
  1949. else:
  1950. mutable = False
  1951. def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
  1952. decoder_module = module._get_decoder_module()
  1953. decoder_outputs = decoder_module(
  1954. decoder_input_ids,
  1955. decoder_attention_mask,
  1956. **kwargs,
  1957. )
  1958. sequence_output = decoder_outputs[0]
  1959. if self.config.tie_word_embeddings:
  1960. # Rescale output before projecting on vocab
  1961. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
  1962. sequence_output = sequence_output * (self.config.d_model**-0.5)
  1963. if self.config.tie_word_embeddings:
  1964. shared_embedding = module.shared.variables["params"]["embedding"]
  1965. lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
  1966. else:
  1967. lm_logits = module.lm_head(sequence_output)
  1968. return lm_logits, decoder_outputs
  1969. outputs = self.module.apply(
  1970. inputs,
  1971. decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
  1972. decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
  1973. encoder_hidden_states=encoder_hidden_states,
  1974. encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
  1975. output_attentions=output_attentions,
  1976. output_hidden_states=output_hidden_states,
  1977. return_dict=return_dict,
  1978. deterministic=not train,
  1979. rngs=rngs,
  1980. mutable=mutable,
  1981. method=_decoder_forward,
  1982. )
  1983. if past_key_values is None:
  1984. lm_logits, decoder_outputs = outputs
  1985. else:
  1986. (lm_logits, decoder_outputs), past = outputs
  1987. if return_dict:
  1988. outputs = FlaxCausalLMOutputWithCrossAttentions(
  1989. logits=lm_logits,
  1990. hidden_states=decoder_outputs.hidden_states,
  1991. attentions=decoder_outputs.attentions,
  1992. cross_attentions=decoder_outputs.cross_attentions,
  1993. )
  1994. else:
  1995. outputs = (lm_logits,) + decoder_outputs[1:]
  1996. # add updated cache to model output
  1997. if past_key_values is not None and return_dict:
  1998. outputs["past_key_values"] = unfreeze(past["cache"])
  1999. return outputs
  2000. elif past_key_values is not None and not return_dict:
  2001. outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
  2002. return outputs
  2003. def prepare_inputs_for_generation(
  2004. self,
  2005. decoder_input_ids,
  2006. max_length,
  2007. attention_mask: Optional[jax.Array] = None,
  2008. decoder_attention_mask: Optional[jax.Array] = None,
  2009. encoder_outputs=None,
  2010. **kwargs,
  2011. ):
  2012. # initializing the cache
  2013. batch_size, seq_length = decoder_input_ids.shape
  2014. past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
  2015. # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
2016. # But since the decoder uses a causal mask, those positions are masked anyway.
  2017. # Thus we can create a single static attention_mask here, which is more efficient for compilation
  2018. extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
  2019. if decoder_attention_mask is not None:
  2020. extended_attention_mask = jax.lax.dynamic_update_slice(
  2021. extended_attention_mask, decoder_attention_mask, (0, 0)
  2022. )
  2023. return {
  2024. "past_key_values": past_key_values,
  2025. "encoder_outputs": encoder_outputs,
  2026. "encoder_attention_mask": attention_mask,
  2027. "decoder_attention_mask": extended_attention_mask,
  2028. }
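# These inputs are consumed by the Flax generation utilities: the cache created above is threaded
# through subsequent decoding steps via update_inputs_for_generation below.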
  2029. def update_inputs_for_generation(self, model_outputs, model_kwargs):
  2030. model_kwargs["past_key_values"] = model_outputs.past_key_values
  2031. return model_kwargs
  2032. FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING = """
  2033. Returns:
  2034. Example:
  2035. ```python
  2036. >>> from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration
  2037. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
  2038. >>> model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
  2039. >>> ARTICLE_TO_SUMMARIZE = "summarize: My friends are cool but they eat too many carbs."
  2040. >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], return_tensors="np")
  2041. >>> # Generate Summary
  2042. >>> summary_ids = model.generate(inputs["input_ids"]).sequences
  2043. >>> print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))
  2044. ```
  2045. """
  2046. overwrite_call_docstring(
  2047. FlaxLongT5ForConditionalGeneration, LONGT5_INPUTS_DOCSTRING + FLAX_LONGT5_CONDITIONAL_GENERATION_DOCSTRING
  2048. )
  2049. append_replace_return_docstrings(
  2050. FlaxLongT5ForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
  2051. )