modeling_longt5.py

# coding=utf-8
# Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LongT5 model."""

import copy
import math
import warnings
from typing import Any, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    is_torchdynamo_compiling,
    logging,
    replace_return_docstrings,
)
from .configuration_longt5 import LongT5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LongT5Config"
_CHECKPOINT_FOR_DOC = "google/long-t5-local-base"

# TODO: Update before the merge


def _pad_to_multiple(x: torch.Tensor, block_len: int, dim: int, pad_value: int = 0) -> torch.Tensor:
    """Pad a tensor so that a sequence length will be a multiple of `block_len`"""
    pad_len = -x.shape[dim] % block_len
    # Handle cases when an empty input sequence is given
    if not all(x.shape):
        new_shape = list(x.shape)
        new_shape[dim] += pad_len
        return torch.zeros(new_shape, dtype=x.dtype)

    pad = [(0, 0)] * x.ndim
    pad[dim] = (0, pad_len)
    pad = sum(pad[::-1], ())
    x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)
    return x
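
# Example (illustrative): padding the sequence dimension up to the next multiple of `block_len`;
# the extra positions are filled with `pad_value` (zeros by default).
#
#   >>> _pad_to_multiple(torch.ones(2, 5), block_len=4, dim=1).shape
#   torch.Size([2, 8])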


def _split_into_blocks(x: torch.Tensor, block_len: int, dim: int) -> torch.Tensor:
    """Split an input tensor into blocks of a given `block_len` along the given `dim`. If the dimension length
    is not a multiple of `block_len`, it will be padded first with selected `pad_value`.
    """
    # pad tensor to multiple of block_len
    if x.shape[dim] % block_len != 0:
        x = _pad_to_multiple(x, block_len, dim, pad_value=0)
    num_blocks = x.shape[dim] // block_len
    output_shape = x.shape[:dim] + (num_blocks, block_len) + x.shape[(dim + 1) :]
    # If 0 is in output_shape, we cannot apply reshape because of incompatibility with ONNX conversion
    if 0 in output_shape:
        return torch.empty(output_shape, dtype=x.dtype, device=x.device)
    return x.reshape(output_shape)
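
# Example (illustrative): a (batch, seq, dim) tensor is padded and folded into fixed-size blocks.
#
#   >>> _split_into_blocks(torch.ones(2, 5, 8), block_len=4, dim=1).shape
#   torch.Size([2, 2, 4, 8])   # (batch_size, num_blocks, block_len, dim)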


def _concatenate_3_blocks(x: torch.Tensor, block_dim: int, sequence_dim: int, pad_value: int = 0) -> torch.Tensor:
    """Concatenate three consecutive blocks for each input block for local attention.

    For more information, see: https://arxiv.org/pdf/2112.07916.pdf.
    """
    num_blocks = x.shape[block_dim]

    pad = [(0, 0)] * x.ndim
    pad[block_dim] = (1, 1)
    pad = sum(pad[::-1], ())
    # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len]
    x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)

    blocks_list: List[torch.Tensor] = []
    for i in range(3):
        # We use indexing approach here:
        # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs
        indices = [slice(0, None)] * x.ndim
        indices[block_dim] = slice(i, i + num_blocks)
        indices = tuple(indices)
        blocks_list.append(x[indices])
    # [batch_size, num_blocks, 3 * block_len, ...]
    return torch.cat(blocks_list, dim=sequence_dim)
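
# Example (illustrative): each block is concatenated with its left and right neighbour along the
# sequence dimension, with zero-padded blocks at the boundaries.
#
#   >>> blocked = _split_into_blocks(torch.ones(2, 5, 8), block_len=4, dim=1)   # (2, 2, 4, 8)
#   >>> _concatenate_3_blocks(blocked, block_dim=1, sequence_dim=2).shape
#   torch.Size([2, 2, 12, 8])   # (batch_size, num_blocks, 3 * block_len, dim)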


def _make_3block_relative_position_ids(block_len: int) -> torch.Tensor:
    """Makes 3-blocked relative position ids for local attention."""
    position_ids = torch.arange(3 * block_len, dtype=torch.int32)
    center_position_ids = position_ids[block_len:-block_len]
    # [block_len, 3 * block_len]
    relative_position_ids = position_ids.unsqueeze(0) - center_position_ids.unsqueeze(1)
    return relative_position_ids


def _mask_local_attention_mask(local_attention_mask: torch.Tensor, block_len: int) -> torch.Tensor:
    """Mask local attention mask to enforce that tokens are not allowed to attend to tokens farther than `local_radius`."""
    relative_position_ids = _make_3block_relative_position_ids(block_len)
    locality_mask = torch.abs(relative_position_ids) < block_len
    locality_mask = locality_mask[None, None, :, :]
    locality_mask = locality_mask.to(local_attention_mask.device)
    return torch.logical_and(local_attention_mask, locality_mask)


def _get_local_attention_mask(attention_mask: torch.Tensor, block_len: int, device: torch.device) -> torch.Tensor:
    """Prepare attention mask to be applied for a local attention."""
    # [batch_size, num_blocks, block_len]
    _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, dim=1)
    # [batch_size, num_blocks, 3 * block_len]
    _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_dim=1, sequence_dim=2)

    _blocked_attention_mask = _blocked_attention_mask.unsqueeze(-1)
    _3blocked_attention_mask = _3blocked_attention_mask.unsqueeze(-2)
    # [batch_size, num_blocks, block_len, 3 * block_len]
    local_attention_mask = torch.logical_and(_blocked_attention_mask, _3blocked_attention_mask)
    local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len)
    # [batch_size, 1, num_blocks, block_len, 3 * block_len]
    return local_attention_mask.unsqueeze(1).to(device)
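
# Example (illustrative): a plain (batch_size, seq_len) attention mask is expanded into the blocked
# form that is added to the local attention scores.
#
#   >>> _get_local_attention_mask(torch.ones(2, 10), block_len=4, device="cpu").shape
#   torch.Size([2, 1, 3, 4, 12])   # (batch_size, 1, num_blocks, block_len, 3 * block_len)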


def _make_global_fixed_block_ids(
    attention_mask: torch.Tensor, global_block_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Obtain the "fixed block" global id corresponding to each input token.

    This implementation is a simplified version of the original Flaxformer implementation adopted from:
    https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py.

    In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not fill a
    whole fixed block, are assigned to the preceding block.

    Padding tokens from the original sequence are represented by -1.
    """
    batch_size, seq_len = attention_mask.shape[:2]

    def handle_orphan_tokens(block_ids: torch.Tensor) -> torch.Tensor:
        block_ends = (torch.arange(seq_len) % global_block_size) == global_block_size - 1
        block_ends = block_ends.to(block_ids.device)
        true_block_ends = torch.logical_and(block_ends, block_ids >= 0)
        full_blocks = true_block_ends.sum(-1).unsqueeze(-1).type(block_ids.dtype) - 1
        block_ids = torch.where(block_ids < full_blocks, block_ids, full_blocks)
        return block_ids

    fixed_block_mask = torch.ones_like(attention_mask, device=attention_mask.device) / global_block_size
    fixed_block_mask = torch.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
    mask = torch.where(attention_mask != 0.0, 1.0, -1000.0).type(attention_mask.dtype)
    global_block_ids = torch.floor(mask + fixed_block_mask - 1.0).type(attention_mask.dtype)
    _global_block_ids_lower_bound = torch.tensor(-1, dtype=global_block_ids.dtype, device=global_block_ids.device)
    global_block_ids = torch.where(
        global_block_ids > _global_block_ids_lower_bound, global_block_ids, _global_block_ids_lower_bound
    )
    # set padding tokens to -1
    global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1)
    # [batch_size, seq_len]
    global_block_ids = handle_orphan_tokens(global_block_ids)
    num_globals = seq_len // global_block_size
    # [batch_size, seq_len // global_block_size]
    if num_globals > 0:
        _sequence_block_ids_max = torch.max(global_block_ids, dim=-1).values.repeat(num_globals, 1).transpose(0, 1)
    else:
        _sequence_block_ids_max = torch.zeros(
            batch_size, 0, dtype=global_block_ids.dtype, device=global_block_ids.device
        )
    global_segment_ids = torch.cumsum(torch.ones(batch_size, num_globals), dim=-1) - 1
    global_segment_ids = global_segment_ids.to(attention_mask.device)
    global_segment_ids = torch.where(global_segment_ids <= _sequence_block_ids_max, 1, 0)
    return global_block_ids.type(torch.int), global_segment_ids.type(torch.int)
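
# Example (illustrative): with a fully-visible mask of length 8 and global_block_size=3, the last two
# "orphan" tokens are folded into the preceding block and two global blocks are produced.
#
#   >>> block_ids, global_segment_ids = _make_global_fixed_block_ids(torch.ones(1, 8), global_block_size=3)
#   >>> block_ids.tolist(), global_segment_ids.tolist()
#   ([[0, 0, 0, 1, 1, 1, 1, 1]], [[1, 1]])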


def _make_side_relative_position_ids(attention_mask: torch.Tensor, global_block_size: int) -> torch.Tensor:
    """Create the relative position tensor for local -> global attention."""
    block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size)
    global_seq_len = global_segment_ids.shape[-1]
    global_positions = torch.arange(global_seq_len, device=block_ids.device)
    side_relative_position = global_positions - block_ids[..., None]
    return side_relative_position.type(torch.int64)


def _create_global_aggregates(
    hidden_states: torch.Tensor, block_ids: torch.Tensor, global_seq_len: int
) -> torch.Tensor:
    """Compute individual block aggregates by summing over individual blocks."""
    # (batch..., seq_len, global_seq_len)
    block_ids = block_ids.where(
        block_ids >= 0, torch.tensor(global_seq_len, dtype=block_ids.dtype, device=block_ids.device)
    )
    one_hot_block_ids = nn.functional.one_hot(block_ids.type(torch.int64), global_seq_len + 1)[:, :, :-1]
    return torch.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids.type(hidden_states.dtype))
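
# Example (illustrative): token states are summed per global block id, so a (1, 8, d_model) input with
# the block ids from the example above yields one aggregate vector per global block; padding ids of -1
# contribute to no block.
#
#   >>> _create_global_aggregates(torch.ones(1, 8, 16), block_ids, global_seq_len=2).shape
#   torch.Size([1, 2, 16])   # (batch_size, global_seq_len, d_model)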


# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->LongT5
class LongT5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the LongT5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # LongT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization https://arxiv.org/abs/1910.07467, thus the variance is calculated
        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
        # half-precision inputs is done in fp32.
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
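
# In short, LongT5LayerNorm computes y = weight * x / sqrt(mean(x**2, dim=-1) + eps) (RMSNorm): the
# activations are rescaled by their root-mean-square but never re-centered, and no bias is added.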


try:
    from apex.normalization import FusedRMSNorm

    LongT5LayerNorm = FusedRMSNorm  # noqa

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of LongT5LayerNorm")
except ImportError:
    # using the normal LongT5LayerNorm
    pass
except Exception:
    logger.warning("discovered apex but it failed to load, falling back to LongT5LayerNorm")
    pass

ALL_LAYERNORM_LAYERS.append(LongT5LayerNorm)


# Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->LongT5
class LongT5DenseActDense(nn.Module):
    def __init__(self, config: LongT5Config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)
        hidden_states = self.wo(hidden_states)
        return hidden_states


class LongT5DenseGatedActDense(nn.Module):
    def __init__(self, config: LongT5Config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


# Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->LongT5
class LongT5LayerFF(nn.Module):
    def __init__(self, config: LongT5Config):
        super().__init__()
        if config.is_gated_act:
            self.DenseReluDense = LongT5DenseGatedActDense(config)
        else:
            self.DenseReluDense = LongT5DenseActDense(config)

        self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states


# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5
class LongT5Attention(nn.Module):
    def __init__(
        self,
        config: LongT5Config,
        has_relative_attention_bias=False,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
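
    # Example (illustrative), with the defaults bidirectional=True, num_buckets=32, max_distance=128:
    # a relative position of 0 falls into bucket 0, -1 into bucket 1 and +1 into bucket 17 (the upper
    # half of the buckets is reserved for positive distances), while all distances <= -128 share
    # bucket 15 and all distances >= +128 share bucket 31.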

    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        else:
            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_value.cross_attention_cache
            else:
                curr_past_key_value = past_key_value.self_attention_cache

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.k(current_states)
            value_states = self.v(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
                )
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked
        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output, past_key_value, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs


class LongT5LocalAttention(nn.Module):
    def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.local_radius = config.local_radius
        self.block_len = self.local_radius + 1
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads
    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, block_length: int):
        """Compute binned relative position bias"""
        target_device = (
            self.relative_attention_bias.weight.device
            if self.relative_attention_bias.weight.device.type != "meta"
            else None
        )
        memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
        context_position = memory_position[block_length:-block_length]

        # (block_length, 3 * block_length)
        relative_position = memory_position[None, :] - context_position[:, None]
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # (block_length, 3 * block_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # (block_length, 3 * block_length, num_heads)
        values = self.relative_attention_bias(relative_position_bucket)
        # (1, 1, num_heads, block_length, 3 * block_length)
        values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        position_bias=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)

        def unshape(states):
            """reshape"""
            return states.contiguous().view(batch_size, -1, self.inner_dim)

        # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head)
        query_states = shape(self.q(hidden_states))
        key_states = shape(self.k(hidden_states))
        value_states = shape(self.v(hidden_states))

        # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
        query_states = _split_into_blocks(query_states, self.block_len, dim=1)
        key_states = _split_into_blocks(key_states, self.block_len, dim=1)
        value_states = _split_into_blocks(value_states, self.block_len, dim=1)

        # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
        key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
        value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)

        # Compute scores
        scores = torch.einsum(
            "...qhd,...khd->...hqk", query_states, key_states
        )  # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)

        if position_bias is None:
            # position_bias shape: (1, 1, n_heads, block_len, 3 * block_len)
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, 1, self.n_heads, self.block_len, 3 * self.block_len), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(self.block_len)

            if mask is not None:
                # Replace masked positions with -1e10 (according to the original implementation)
                mask = torch.where(mask > 0, 0.0, -1e10)
                # We need to adjust the position bias shape so it can be summed with the mask
                position_bias = position_bias + mask.transpose(1, 2)

        scores += position_bias
        # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask
        attn_weights = attn_weights.type(value_states.dtype)
        attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states))
        attn_output = attn_output[:, :seq_length, :]
        attn_output = self.o(attn_output)

        present_key_value_state = None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
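
# Illustrative shapes for LongT5LocalAttention: with e.g. local_radius=127 (so block_len=128, the
# google/long-t5-local-base setting) and an input of shape (batch_size, 2048, d_model), queries are
# folded into 16 blocks of 128 tokens and each block attends to a 3 * 128 = 384-token window of
# keys/values, giving scores of shape (batch_size, 16, n_heads, 128, 384) instead of a full
# (batch_size, n_heads, 2048, 2048) matrix.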


class LongT5TransientGlobalAttention(nn.Module):
    def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.local_radius = config.local_radius
        self.block_len = self.local_radius + 1
        self.global_block_size = config.global_block_size
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()

        # Relative attention bias & layer norm for global attention
        if self.has_relative_attention_bias:
            self.global_relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.global_input_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)

    # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads
    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, block_length: int):
        """Compute binned relative position bias"""
        target_device = (
            self.relative_attention_bias.weight.device
            if self.relative_attention_bias.weight.device.type != "meta"
            else None
        )
        memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
        context_position = memory_position[block_length:-block_length]

        # (block_length, 3 * block_length)
        relative_position = memory_position[None, :] - context_position[:, None]
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # (block_length, 3 * block_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # (block_length, 3 * block_length, num_heads)
        values = self.relative_attention_bias(relative_position_bucket)
        # (1, 1, num_heads, block_length, 3 * block_length)
        values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0)
        return values

    def compute_side_bias(self, mask: torch.Tensor, global_segment_ids: torch.Tensor) -> torch.Tensor:
        # (batch_size, 1, seq_len, global_seq_len)
        side_attention_mask = torch.eq(mask[..., None], global_segment_ids[:, None, :])[:, None, ...]
        attention_side_bias = torch.where(side_attention_mask > 0, 0.0, -1e10)
        # (batch_size, seq_len, global_seq_len)
        side_relative_position = _make_side_relative_position_ids(mask, self.global_block_size)
        side_relative_position_bucket = self._relative_position_bucket(
            side_relative_position,
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        # (batch_size, seq_len, global_seq_len, num_heads)
        side_bias = self.global_relative_attention_bias(side_relative_position_bucket)

        # (batch_size, num_heads, seq_len, global_seq_len)
        side_bias = side_bias.permute([0, 3, 1, 2])
        # (batch_size, num_heads, seq_len, global_seq_len)
        attention_side_bias = attention_side_bias + side_bias
        return attention_side_bias

    def forward(
        self,
        hidden_states,
        mask=None,
        position_bias=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)

        def unshape(states):
            """reshape"""
            return states.contiguous().view(batch_size, -1, self.inner_dim)

        # Prepare components for transient-global attention
        # Obtain block_ids and global_segment_ids
        # global_seq_len := seq_len // self.global_block_size
        # shapes: (batch_size, seq_len) & (batch_size, global_seq_len)
        block_ids, global_segment_ids = _make_global_fixed_block_ids(
            mask if mask is not None else torch.ones(hidden_states.shape[:-1]),
            self.global_block_size,
        )
        # Create global inputs
        _global_seq_len = global_segment_ids.shape[-1]
        global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
        global_inputs = self.global_input_layer_norm(global_inputs)

        # get query states -> (batch_size, seq_length, n_heads, dim_per_head)
        query_states = shape(self.q(hidden_states))
        key_states = shape(self.k(hidden_states))
        value_states = shape(self.v(hidden_states))
        # Get global/side key/value states, shape: (batch_size, global_seq_len, n_heads, dim_per_head)
        side_key_states = shape(self.k(global_inputs))
        side_value_states = shape(self.v(global_inputs))

        # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
        query_states = _split_into_blocks(query_states, self.block_len, dim=1)
        key_states = _split_into_blocks(key_states, self.block_len, dim=1)
        value_states = _split_into_blocks(value_states, self.block_len, dim=1)

        # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
        key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
        value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)

        # Tile side inputs across local key/value blocks
        # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head)
        reps = [1] * (side_key_states.ndim + 1)
        reps[1] = key_states.shape[1]
        side_key_states = side_key_states.unsqueeze(1).repeat(reps)
        side_value_states = side_value_states.unsqueeze(1).repeat(reps)

        # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones
        # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head)
        key_states = torch.cat([key_states, side_key_states], dim=2)
        value_states = torch.cat([value_states, side_value_states], dim=2)

        # Compute scores -> (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len)
        scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states)

        if mask is not None:
            # We need to adjust the position bias shape so it can be summed with the mask
            local_attention_mask = _get_local_attention_mask(mask, self.block_len, hidden_states.device)
            # Replace masked positions with -1e10 (according to the original implementation)
            local_attention_mask = torch.where(local_attention_mask > 0, 0.0, -1e10)
        else:
            local_attention_mask = None

        if position_bias is None:
            # position_bias shape: (1, 1, n_heads, block_len, 3 * block_len)
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, 1, self.n_heads, self.block_len, 3 * self.block_len),
                    device=scores.device,
                    dtype=scores.dtype,
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(self.block_len)

            if local_attention_mask is not None:
                # (batch_size, 1, n_heads, block_len, 3 * block_len)
                position_bias = position_bias + local_attention_mask.transpose(1, 2)
            position_bias = position_bias.type(scores.dtype)

            # Calculate global/side bias - shape: (batch_size, num_heads, seq_len, global_seq_len)
            if mask is None:
                mask = torch.ones(batch_size, seq_length)
            # (batch_size, num_heads, seq_len, global_seq_len)
            side_position_bias = self.compute_side_bias(mask, global_segment_ids)
            # (batch_size, num_blocks, num_heads, block_len, global_seq_len)
            side_position_bias = _split_into_blocks(side_position_bias, self.block_len, dim=-2).transpose(1, 2)
            side_position_bias = side_position_bias.type(scores.dtype).to(scores.device)
            # (batch_size, num_blocks, num_heads, block_len, 3 * block_len + global_seq_len)
            position_bias = torch.cat([position_bias, side_position_bias], dim=-1)

        scores += position_bias
        # (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask
        attn_weights = attn_weights.type(value_states.dtype)
        attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states))
        attn_output = attn_output[:, :seq_length, :]
        attn_output = self.o(attn_output)

        present_key_value_state = None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
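
# Illustrative shapes for LongT5TransientGlobalAttention: with e.g. local_radius=127 (block_len=128)
# and global_block_size=16 (roughly the google/long-t5-tglobal-base setting), a 2048-token input yields
# 2048 // 16 = 128 transient global tokens obtained by summing each 16-token block, and every query
# block of 128 tokens attends to 3 * 128 local keys plus the 128 global keys, i.e. scores of shape
# (batch_size, 16, n_heads, 128, 384 + 128).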


# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5
class LongT5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.SelfAttention = LongT5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
        )
        self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class LongT5LayerLocalSelfAttention(nn.Module):
    """Local self attention used in encoder"""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        output_attentions=False,
        **kwargs: Any,  # to accept past_key_value and use_cache kwargs
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.LocalSelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class LongT5LayerTransientGlobalSelfAttention(nn.Module):
    """Transient-Global self attention used in encoder"""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention(
            config, has_relative_attention_bias=has_relative_attention_bias
        )
        self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        output_attentions=False,
        **kwargs: Any,  # to accept past_key_value and use_cache kwargs
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.TransientGlobalSelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5
class LongT5LayerCrossAttention(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
        cache_position=None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class LongT5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.is_decoder = config.is_decoder
        if config.is_decoder:
            attention_layer = LongT5LayerSelfAttention
        elif config.encoder_attention_type == "local":
            attention_layer = LongT5LayerLocalSelfAttention
        elif config.encoder_attention_type == "transient-global":
            attention_layer = LongT5LayerTransientGlobalSelfAttention
        else:
            raise ValueError(
                "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, "
                f"but got {config.encoder_attention_type}."
            )
        self.layer = nn.ModuleList()
        self.layer.append(
            attention_layer(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
        )
        if self.is_decoder:
            self.layer.append(LongT5LayerCrossAttention(config, layer_idx=layer_idx))

        self.layer.append(LongT5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
    ):
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states, past_key_value = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            hidden_states, past_key_value = cross_attention_outputs[:2]

            # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (past_key_value,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights),
        # (cross-attention position bias), (cross-attention weights)
        return outputs


class LongT5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LongT5Config
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LongT5Block"]
    _supports_cache_class = True
    _supports_static_cache = False  # TODO: @raushan more involved due to local/global attn

    @property
    # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, LongT5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, (LongT5Model, LongT5ForConditionalGeneration, LongT5EncoderModel)):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, LongT5DenseActDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, LongT5DenseGatedActDense):
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, (LongT5Attention, LongT5LocalAttention, LongT5TransientGlobalAttention)):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
            if isinstance(module, LongT5TransientGlobalAttention):
                module.global_relative_attention_bias.weight.data.normal_(
                    mean=0.0, std=factor * ((d_model) ** -0.5)
                )

    # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5
    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In LongT5 it is usually set to the pad_token_id. "
                "See LongT5 docs for more information."
            )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
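
    # A minimal illustration of the shift above, assuming pad_token_id = 0 is also used as the
    # decoder_start_token_id (the usual LongT5 setup): labels [[37, 423, -100, -100]] become
    # decoder_input_ids [[0, 37, 423, 0]] - every token moves one position to the right, the start
    # token is prepended, the last label is dropped, and the remaining -100 is replaced by the pad id.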


class LongT5Stack(LongT5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight
        self.is_decoder = config.is_decoder

        self.local_radius = config.local_radius
        self.block_len = self.local_radius + 1

        self.block = nn.ModuleList(
            [
                LongT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
                for i in range(config.num_layers)
            ]
        )
        self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.t5.modeling_t5.T5Stack.get_input_embeddings
    def get_input_embeddings(self):
        return self.embed_tokens

    # Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings
    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # initialize past_key_values
        return_legacy_cache = False
        return_self_attention_cache = False
        if self.is_decoder and (use_cache or past_key_values is not None):
            if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
                return_self_attention_cache = True
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
            elif not isinstance(past_key_values, EncoderDecoderCache):
                return_legacy_cache = True
                logger.warning_once(
                    "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
                    "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                    "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
                )
                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
            elif past_key_values is None:
                past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
        elif not self.is_decoder:
            # do not pass cache object down the line for encoder stack
            # it messes indexing later in decoder-stack because cache object is modified in-place
            past_key_values = None

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
            )

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        if self.is_decoder:
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values.self_attention_cache if past_key_values is not None else None,
                output_attentions,
            )
        # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used
        elif self.config.encoder_attention_type == "local":
            causal_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
        else:  # we need to use both local attention mask and standard extended mask for transient-global attention
            causal_mask = attention_mask

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, layer_module in enumerate(self.block):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.forward,
                    hidden_states,
                    causal_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value is always None with gradient checkpointing
                    use_cache,
                    output_attentions,
                    return_dict,
                    cache_position,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    return_dict=return_dict,
                    cache_position=cache_position,
                )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, next_decoder_cache = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them
            position_bias = layer_outputs[2]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_self_attention_cache:
            next_cache = past_key_values.self_attention_cache
        if return_legacy_cache:
            next_cache = past_key_values.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
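
    # A rough sketch of the cache handling in `forward` above, for the decoder stack: a bare
    # `DynamicCache` (or a legacy tuple) passed as `past_key_values` is wrapped into an
    # `EncoderDecoderCache` before the layers run, and the matching format (the self-attention
    # cache, a legacy tuple, or the full `EncoderDecoderCache`) is restored in `next_cache`
    # before the outputs are returned.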

    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
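
    # A small worked example of the helper above, assuming batch_size=1, sequence_length=3,
    # target_length=3, cache_position=[0, 1, 2] and no padding: the resulting (1, 1, 3, 3) mask
    # keeps 0.0 on and below the diagonal and fills positions above the diagonal with
    # torch.finfo(dtype).min, so each query token can only attend to itself and earlier positions.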


LONGT5_START_DOCSTRING = r"""

    The LongT5 model was proposed in [LongT5: Efficient Text-To-Text Transformer for Long
    Sequences](https://arxiv.org/abs/2112.07916) by Mandy Guo, Joshua Ainslie, David Uthus, Santiago Ontanon, Jianmo
    Ni, Yun-Hsuan Sung and Yinfei Yang. It's an encoder-decoder transformer pre-trained in a text-to-text denoising
    generative setting. The LongT5 model is an extension of the T5 model, and it enables using one of two different
    efficient attention mechanisms - (1) Local attention, or (2) Transient-Global attention.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`LongT5Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

LONGT5_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
            you should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [LONGT5
            Training](./longt5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5
            Training](./longt5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
            cache in the correct position and to infer the complete sequence length.
"""

LONGT5_ENCODER_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
            you should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            To know more on how to prepare `input_ids` for pretraining take a look at [LONGT5
            Training](./longt5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""


@add_start_docstrings(
    "The bare LONGT5 Model transformer outputting raw hidden-states without any specific head on top.",
    LONGT5_START_DOCSTRING,
)
class LongT5Model(LongT5PreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: LongT5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = LongT5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = LongT5Stack(decoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LongT5Model

        >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
        >>> model = LongT5Model.from_pretrained("google/long-t5-local-base")

        >>> # Let's try a very long encoder input.
        >>> input_ids = tokenizer(
        ...     100 * "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1

        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1

        >>> # forward pass
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )


@add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING)
class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
    _keys_to_ignore_on_load_unexpected = [
        r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
    ]
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: LongT5Config):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = LongT5Stack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = LongT5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(LONGT5_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
        >>> model = LongT5ForConditionalGeneration.from_pretrained(
        ...     "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"
        ... )

        >>> # Let's try a very long input.
        >>> inputs = tokenizer(100 * "studies have shown that owning a dog is good for you ", return_tensors="pt")
        >>> input_ids = inputs.input_ids

        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        abstractthe aim of this article is to provide an overview of the literature on the role of dog
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past_key_values, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past_key_values

        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # select the correct batch indices from the batch dimension (dim 0) of each past state
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
            assert len(reordered_layer_past_states) == len(layer_past_states)

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past


@add_start_docstrings(
    "The bare LONGT5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
    LONGT5_START_DOCSTRING,
)
class LongT5EncoderModel(LongT5PreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight"]
    _keys_to_ignore_on_load_unexpected = [r"decoder"]

    def __init__(self, config: LongT5Config):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = LongT5Stack(encoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(LONGT5_ENCODER_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LongT5EncoderModel

        >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
        >>> model = LongT5EncoderModel.from_pretrained("google/long-t5-local-base")
        >>> input_ids = tokenizer(
        ...     100 * "Studies have been shown that owning a dog is good for you ", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs