# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch T5 model."""

import copy
import math
import os
import warnings
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    is_torchdynamo_compiling,
    logging,
    replace_return_docstrings,
)
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_t5 import T5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
_CHECKPOINT_FOR_DOC = "google-t5/t5-small"
####################################################
# This dict contains ids and associated url
# for the pretrained weights provided with the models
####################################################

####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    tf_weights = {}
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        tf_weights[name] = array

    for txt_name in names:
        name = txt_name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        if "_slot_" in name[-1]:
            logger.info(f"Skipping {'/'.join(name)}")
            tf_weights.pop(txt_name, None)
            continue
        pointer = model
        array = tf_weights[txt_name]

        for m_name in name:
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ["kernel", "scale", "embedding"]:
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "self_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[0]
            elif scope_names[0] == "enc_dec_attention":
                pointer = getattr(pointer, "layer")
                pointer = pointer[1]
            elif scope_names[0] == "dense_relu_dense":
                pointer = getattr(pointer, "layer")
                pointer = pointer[2]
            elif scope_names[0] == "rms_norm":
                if hasattr(pointer, "layer_norm"):
                    pointer = getattr(pointer, "layer_norm")
                elif hasattr(pointer, "final_layer_norm"):
                    pointer = getattr(pointer, "final_layer_norm")
            elif scope_names[0] == "scale":
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
                pointer = getattr(pointer, "bias")
            elif scope_names[0] == "squad":
                pointer = getattr(pointer, "classifier")
            elif scope_names[0] == "decoder" and name[1] == "logits":
                continue
            elif scope_names[0] == "logits":
                pointer = getattr(pointer, "lm_head")
            elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
                pointer = getattr(pointer, f"wi_{scope_names[1]}")
                continue
            else:
                try:
                    pointer = getattr(pointer, scope_names[0])
                except AttributeError:
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                num = int(scope_names[1])
                pointer = pointer[num]
        if scope_names[0] not in ["kernel", "scale", "embedding"]:
            pointer = getattr(pointer, "weight")
        if scope_names[0] != "embedding":
            logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
            array = np.transpose(array)
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except ValueError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array.astype(np.float32))
        tf_weights.pop(txt_name, None)

    logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
    return model
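
# Minimal usage sketch for the converter above (hypothetical paths), in the spirit of the standalone
# TF -> PyTorch conversion script; the config file must match the TF checkpoint's architecture:
#
#   config = T5Config.from_json_file("/path/to/t5/config.json")
#   model = T5ForConditionalGeneration(config)
#   load_tf_weights_in_t5(model, config, "/path/to/tf_checkpoint")
#   model.save_pretrained("/path/to/pytorch_dump")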
####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (itself a sub-class of nn.Module)
####################################################
PARALLELIZE_DOCSTRING = r"""
    This is an experimental feature and is subject to change at a moment's notice.

    Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
    it will evenly distribute blocks across all devices.

    Args:
        device_map (`Dict[int, list]`, *optional*):
            A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
            automatically mapped to the first device (for esoteric reasons). That means that the first device should
            have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
            following number of attention modules:

                - google-t5/t5-small: 6
                - google-t5/t5-base: 12
                - google-t5/t5-large: 24
                - google-t5/t5-3b: 24
                - google-t5/t5-11b: 24

    Example:

    ```python
    # Here is an example of a device map on a machine with 4 GPUs using google-t5/t5-3b, which has a total of 24 attention modules:
    model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)
    ```
"""
DEPARALLELIZE_DOCSTRING = r"""
    Moves the model to cpu from a model parallel state.

    Example:

    ```python
    # On a 4 GPU machine with google-t5/t5-3b:
    model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b")
    device_map = {
        0: [0, 1, 2],
        1: [3, 4, 5, 6, 7, 8, 9],
        2: [10, 11, 12, 13, 14, 15, 16],
        3: [17, 18, 19, 20, 21, 22, 23],
    }
    model.parallelize(device_map)  # Splits the model across several devices
    model.deparallelize()  # Puts the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
    ```
"""
class T5LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
        # half-precision inputs is done in fp32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states
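
# In short, T5LayerNorm computes y = weight * x / sqrt(mean(x**2) + eps), i.e. RMSNorm.
# Illustrative sketch with hypothetical values: for a last-dimension slice x = [3.0, 4.0],
# mean(x**2) = 12.5, so x is scaled by 1 / sqrt(12.5 + eps) ~= 0.283 before the learned
# per-channel weight is applied.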
try:
    from apex.normalization import FusedRMSNorm

    T5LayerNorm = FusedRMSNorm  # noqa

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
except ImportError:
    # using the normal T5LayerNorm
    pass
except Exception:
    logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
    pass

ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
class T5DenseActDense(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_states = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.dropout(hidden_states)
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)
        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedActDense(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)

        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
        # See https://github.com/huggingface/transformers/issues/20287
        # we also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)

        hidden_states = self.wo(hidden_states)
        return hidden_states


class T5LayerFF(nn.Module):
    def __init__(self, config: T5Config):
        super().__init__()
        if config.is_gated_act:
            self.DenseReluDense = T5DenseGatedActDense(config)
        else:
            self.DenseReluDense = T5DenseActDense(config)

        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states
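
# The feed-forward sub-layer above follows T5's pre-norm residual pattern:
#     y = x + Dropout(FF(LayerNorm(x)))
# where FF(h) is either wo(act(wi(h))) (T5DenseActDense) or, for gated checkpoints such as
# Flan-T5, wo(act(wi_0(h)) * wi_1(h)) (T5DenseGatedActDense).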
class T5Attention(nn.Module):
    def __init__(
        self,
        config: T5Config,
        has_relative_attention_bias=False,
        layer_idx: Optional[int] = None,
    ):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.layer_idx = layer_idx
        if layer_idx is None and self.is_decoder:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on.

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
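
    # Illustration of the bucket layout for the defaults (num_buckets=32, max_distance=128, bidirectional=True):
    # each direction gets 16 buckets. Distances 0-7 map to their own "exact" buckets 0-7, distances 8-127 fall
    # into logarithmically sized buckets 8-15, and anything >= 128 is clamped into the last bucket. Positions
    # with relative_position > 0 (attending forward) are shifted by 16, keeping the two directions distinct.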
    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        else:
            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values
    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_value.cross_attention_cache
            else:
                curr_past_key_value = past_key_value.self_attention_cache

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.k(current_states)
            value_states = self.v(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
                )
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output, past_key_value, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
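
# Note: unlike standard scaled dot-product attention, T5 does not divide the scores by sqrt(d_kv);
# the "Mesh TensorFlow initialization to avoid scaling before softmax" in __init__ (and in
# T5PreTrainedModel._init_weights) folds that factor into the query initialization instead, so
# scores = Q @ K^T + position_bias go straight into the softmax.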
class T5LayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.SelfAttention = T5Attention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
        )
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class T5LayerCrossAttention(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
        cache_position=None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs
class T5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(
            T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
        )
        if self.is_decoder:
            self.layer.append(T5LayerCrossAttention(config, layer_idx=layer_idx))

        self.layer.append(T5LayerFF(config))

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
    ):
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states, past_key_value = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states, past_key_value = cross_attention_outputs[:2]

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16:
                clamp_value = torch.where(
                    torch.isinf(hidden_states).any(),
                    torch.finfo(hidden_states.dtype).max - 1000,
                    torch.finfo(hidden_states.dtype).max,
                )
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (past_key_value,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs  # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
class T5ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, config: T5Config):
        super().__init__()
        self.dense = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(p=config.classifier_dropout)
        self.out_proj = nn.Linear(config.d_model, config.num_labels)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states
class T5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _supports_quantized_cache = False  # enc-dec models don't support yet
    _supports_static_cache = True
    _supports_cache_class = True
    _no_split_modules = ["T5Block"]
    _keep_in_fp32_modules = ["wo"]

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(
            module,
            (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering),
        ):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "qa_outputs"):
                module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
                module.qa_outputs.bias.data.zero_()
        elif isinstance(module, T5ForTokenClassification):
            if hasattr(module, "classifier"):
                module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0)
                module.classifier.bias.data.zero_()
        elif isinstance(module, T5ClassificationHead):
            module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.dense, "bias") and module.dense.bias is not None:
                module.dense.bias.data.zero_()
            module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
                module.out_proj.bias.data.zero_()
        elif isinstance(module, T5DenseActDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5DenseGatedActDense):
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, T5Attention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
                "See T5 docs for more information."
            )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
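
    # Illustrative example of `_shift_right` with hypothetical ids and decoder_start_token_id = pad_token_id = 0:
    # labels  [[37, 32099, 1, -100]]  ->  decoder_input_ids  [[0, 37, 32099, 1]].
    # Every token moves one slot to the right, the start token is prepended, the last label is dropped by the
    # shift, and any -100 (loss-ignored positions) remaining in the shifted tensor is replaced by pad_token_id.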
class T5Stack(T5PreTrainedModel):
    def __init__(self, config, embed_tokens=None):
        super().__init__(config)

        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        warnings.warn(
            "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
            " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
            " 'block.1': 1, ...}",
            FutureWarning,
        )
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)
        # Set embed_tokens to first layer
        self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        self.embed_tokens = self.embed_tokens.to("cpu")
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError("You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        if use_cache is True:
            if not self.is_decoder:
                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

        # initialize past_key_values
        return_legacy_cache = False
        return_self_attention_cache = False
        if self.is_decoder and (use_cache or past_key_values is not None):
            if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
                return_self_attention_cache = True
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
            elif not isinstance(past_key_values, EncoderDecoderCache):
                return_legacy_cache = True
                logger.warning_once(
                    "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
                    "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                    "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
                )
                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
            elif past_key_values is None:
                past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
        elif not self.is_decoder:
            # do not pass cache object down the line for encoder stack
            # it messes indexing later in decoder-stack because cache object is modified in-place
            past_key_values = None

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
            )

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        if self.config.is_decoder:
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values.self_attention_cache if past_key_values is not None else None,
                output_attentions,
            )
        elif attention_mask is not None:
            causal_mask = attention_mask[:, None, None, :]
            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
        else:
            causal_mask = None

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
                )
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)
  950. for i, layer_module in enumerate(self.block):
  951. layer_head_mask = head_mask[i]
  952. cross_attn_layer_head_mask = cross_attn_head_mask[i]
  953. # Model parallel
  954. if self.model_parallel:
  955. torch.cuda.set_device(hidden_states.device)
  956. # Ensure that attention_mask is always on the same device as hidden_states
  957. if causal_mask is not None:
  958. causal_mask = causal_mask.to(hidden_states.device)
  959. if position_bias is not None:
  960. position_bias = position_bias.to(hidden_states.device)
  961. if encoder_hidden_states is not None:
  962. encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
  963. if encoder_extended_attention_mask is not None:
  964. encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
  965. if encoder_decoder_position_bias is not None:
  966. encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
  967. if layer_head_mask is not None:
  968. layer_head_mask = layer_head_mask.to(hidden_states.device)
  969. if cross_attn_layer_head_mask is not None:
  970. cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
  971. if output_hidden_states:
  972. all_hidden_states = all_hidden_states + (hidden_states,)
  973. if self.gradient_checkpointing and self.training:
  974. layer_outputs = self._gradient_checkpointing_func(
  975. layer_module.forward,
  976. hidden_states,
  977. causal_mask,
  978. position_bias,
  979. encoder_hidden_states,
  980. encoder_extended_attention_mask,
  981. encoder_decoder_position_bias,
  982. layer_head_mask,
  983. cross_attn_layer_head_mask,
  984. None, # past_key_value is always None with gradient checkpointing
  985. use_cache,
  986. output_attentions,
  987. return_dict,
  988. cache_position,
  989. )
  990. else:
  991. layer_outputs = layer_module(
  992. hidden_states,
  993. attention_mask=causal_mask,
  994. position_bias=position_bias,
  995. encoder_hidden_states=encoder_hidden_states,
  996. encoder_attention_mask=encoder_extended_attention_mask,
  997. encoder_decoder_position_bias=encoder_decoder_position_bias,
  998. layer_head_mask=layer_head_mask,
  999. cross_attn_layer_head_mask=cross_attn_layer_head_mask,
  1000. past_key_value=past_key_values,
  1001. use_cache=use_cache,
  1002. output_attentions=output_attentions,
  1003. return_dict=return_dict,
  1004. cache_position=cache_position,
  1005. )
  1006. # layer_outputs is a tuple with:
  1007. # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
  1008. if use_cache is False:
  1009. layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
  1010. hidden_states, next_decoder_cache = layer_outputs[:2]
1011. # We share the position biases between the layers - the first layer stores them
  1012. # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
  1013. # (cross-attention position bias), (cross-attention weights)
  1014. position_bias = layer_outputs[2]
  1015. if self.is_decoder and encoder_hidden_states is not None:
  1016. encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
  1017. if output_attentions:
  1018. all_attentions = all_attentions + (layer_outputs[3],)
  1019. if self.is_decoder:
  1020. all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
  1021. # Model Parallel: If it's the last layer for that device, put things on the next device
  1022. if self.model_parallel:
  1023. for k, v in self.device_map.items():
  1024. if i == v[-1] and "cuda:" + str(k) != self.last_device:
  1025. hidden_states = hidden_states.to("cuda:" + str(k + 1))
  1026. hidden_states = self.final_layer_norm(hidden_states)
  1027. hidden_states = self.dropout(hidden_states)
  1028. # Add last layer
  1029. if output_hidden_states:
  1030. all_hidden_states = all_hidden_states + (hidden_states,)
  1031. next_cache = next_decoder_cache if use_cache else None
  1032. if return_self_attention_cache:
  1033. next_cache = past_key_values.self_attention_cache
  1034. if return_legacy_cache:
  1035. next_cache = past_key_values.to_legacy_cache()
  1036. if not return_dict:
  1037. return tuple(
  1038. v
  1039. for v in [
  1040. hidden_states,
  1041. next_cache,
  1042. all_hidden_states,
  1043. all_attentions,
  1044. all_cross_attentions,
  1045. ]
  1046. if v is not None
  1047. )
  1048. return BaseModelOutputWithPastAndCrossAttentions(
  1049. last_hidden_state=hidden_states,
  1050. past_key_values=next_cache,
  1051. hidden_states=all_hidden_states,
  1052. attentions=all_attentions,
  1053. cross_attentions=all_cross_attentions,
  1054. )
  1055. # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
  1056. def _update_causal_mask(
  1057. self,
  1058. attention_mask: torch.Tensor,
  1059. input_tensor: torch.Tensor,
  1060. cache_position: torch.Tensor,
  1061. past_key_values: Cache,
  1062. output_attentions: bool,
  1063. ):
  1064. if self.config._attn_implementation == "flash_attention_2":
  1065. if attention_mask is not None and 0.0 in attention_mask:
  1066. return attention_mask
  1067. return None
  1068. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  1069. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  1070. # to infer the attention mask.
  1071. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1072. using_static_cache = isinstance(past_key_values, StaticCache)
  1073. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  1074. if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
  1075. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  1076. attention_mask,
  1077. inputs_embeds=input_tensor,
  1078. past_key_values_length=past_seen_tokens,
  1079. is_training=self.training,
  1080. ):
  1081. return None
  1082. dtype, device = input_tensor.dtype, input_tensor.device
  1083. sequence_length = input_tensor.shape[1]
  1084. if using_static_cache:
  1085. target_length = past_key_values.get_max_cache_shape()
  1086. else:
  1087. target_length = (
  1088. attention_mask.shape[-1]
  1089. if isinstance(attention_mask, torch.Tensor)
  1090. else past_seen_tokens + sequence_length + 1
  1091. )
  1092. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  1093. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  1094. attention_mask,
  1095. sequence_length=sequence_length,
  1096. target_length=target_length,
  1097. dtype=dtype,
  1098. device=device,
  1099. cache_position=cache_position,
  1100. batch_size=input_tensor.shape[0],
  1101. )
  1102. if (
  1103. self.config._attn_implementation == "sdpa"
  1104. and attention_mask is not None
  1105. and attention_mask.device.type == "cuda"
  1106. and not output_attentions
  1107. ):
  1108. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  1109. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  1110. # Details: https://github.com/pytorch/pytorch/issues/110213
  1111. min_dtype = torch.finfo(dtype).min
  1112. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  1113. return causal_mask
  1114. @staticmethod
  1115. # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
  1116. def _prepare_4d_causal_attention_mask_with_cache_position(
  1117. attention_mask: torch.Tensor,
  1118. sequence_length: int,
  1119. target_length: int,
  1120. dtype: torch.dtype,
  1121. device: torch.device,
  1122. cache_position: torch.Tensor,
  1123. batch_size: int,
  1124. **kwargs,
  1125. ):
  1126. """
  1127. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
1128. `(batch_size, key_value_length)`. If the input `attention_mask` is already 4D, it is returned unchanged.
  1129. Args:
  1130. attention_mask (`torch.Tensor`):
  1131. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  1132. `(batch_size, 1, query_length, key_value_length)`.
  1133. sequence_length (`int`):
  1134. The sequence length being processed.
  1135. target_length (`int`):
  1136. The target length: when generating with static cache, the mask should be as long as the static cache,
1137. to account for the 0 padding, i.e. the part of the cache that is not yet filled.
  1138. dtype (`torch.dtype`):
  1139. The dtype to use for the 4D attention mask.
  1140. device (`torch.device`):
1141. The device to place the 4D attention mask on.
  1142. cache_position (`torch.Tensor`):
  1143. Indices depicting the position of the input sequence tokens in the sequence.
1144. batch_size (`int`):
  1145. Batch size.
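Example:
A minimal shape-check sketch (toy values; the enclosing stack class is assumed to be `T5Stack`):
```python
>>> import torch
>>> attention_mask = torch.tensor([[1, 1, 1, 1, 1], [0, 0, 1, 1, 1]])  # batch of 2, cache length 5
>>> cache_position = torch.arange(2, 5)  # positions of the 3 tokens currently being processed
>>> mask = T5Stack._prepare_4d_causal_attention_mask_with_cache_position(
...     attention_mask,
...     sequence_length=3,
...     target_length=5,
...     dtype=torch.float32,
...     device="cpu",
...     cache_position=cache_position,
...     batch_size=2,
... )
>>> mask.shape  # (batch_size, 1, query_length, key_value_length)
torch.Size([2, 1, 3, 5])
```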
  1146. """
  1147. if attention_mask is not None and attention_mask.dim() == 4:
  1148. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  1149. causal_mask = attention_mask
  1150. else:
  1151. min_dtype = torch.finfo(dtype).min
  1152. causal_mask = torch.full(
  1153. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
  1154. )
  1155. if sequence_length != 1:
  1156. causal_mask = torch.triu(causal_mask, diagonal=1)
  1157. causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
  1158. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  1159. if attention_mask is not None:
  1160. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  1161. mask_length = attention_mask.shape[-1]
  1162. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
  1163. padding_mask = padding_mask == 0
  1164. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  1165. padding_mask, min_dtype
  1166. )
  1167. return causal_mask
  1168. T5_START_DOCSTRING = r"""
  1169. The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text
  1170. Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan
1171. Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder-decoder transformer pre-trained in a
  1172. text-to-text denoising generative setting.
  1173. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1174. library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
  1175. etc.)
  1176. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  1177. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  1178. and behavior.
  1179. Parameters:
  1180. config ([`T5Config`]): Model configuration class with all the parameters of the model.
  1181. Initializing with a config file does not load the weights associated with the model, only the
  1182. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  1183. """
  1184. T5_INPUTS_DOCSTRING = r"""
  1185. Args:
  1186. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1187. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1188. should be able to pad the inputs on both the right and the left.
  1189. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1190. [`PreTrainedTokenizer.__call__`] for details.
  1191. [What are input IDs?](../glossary#input-ids)
1192. To learn more about how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
  1193. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1194. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1195. - 1 for tokens that are **not masked**,
  1196. - 0 for tokens that are **masked**.
  1197. [What are attention masks?](../glossary#attention-mask)
  1198. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1199. Indices of decoder input sequence tokens in the vocabulary.
  1200. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1201. [`PreTrainedTokenizer.__call__`] for details.
  1202. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1203. T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1204. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
1205. To learn more about how to prepare `decoder_input_ids` for pretraining, take a look at [T5
  1206. Training](./t5#training).
  1207. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1208. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1209. be used by default.
  1210. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1211. Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
  1212. 1]`:
  1213. - 1 indicates the head is **not masked**,
  1214. - 0 indicates the head is **masked**.
  1215. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1216. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1217. 1]`:
  1218. - 1 indicates the head is **not masked**,
  1219. - 0 indicates the head is **masked**.
  1220. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1221. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1222. `[0, 1]`:
  1223. - 1 indicates the head is **not masked**,
  1224. - 0 indicates the head is **masked**.
1225. encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
  1226. Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
  1227. `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
  1228. the output of the last layer of the encoder. Used in the cross-attention of the decoder.
  1229. past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
  1230. Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
  1231. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
  1232. don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
  1233. `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  1234. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1235. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1236. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1237. model's internal embedding lookup matrix.
  1238. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
  1239. Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
  1240. representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
  1241. input (see `past_key_values`). This is useful if you want more control over how to convert
  1242. `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
  1243. If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
  1244. of `inputs_embeds`.
  1245. use_cache (`bool`, *optional*):
  1246. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  1247. `past_key_values`).
  1248. output_attentions (`bool`, *optional*):
  1249. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  1250. tensors for more detail.
  1251. output_hidden_states (`bool`, *optional*):
  1252. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  1253. more detail.
  1254. return_dict (`bool`, *optional*):
  1255. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1256. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  1257. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  1258. cache in the correct position and to infer the complete sequence length.
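A minimal sketch of reusing `past_key_values` across decoding steps (`model`, `input_ids`, `decoder_input_ids` and
`next_decoder_token` - the newest decoder token of shape `(batch_size, 1)` - are placeholders):
```python
>>> # First pass over the full prefix caches the key/value states
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, use_cache=True)
>>> past = outputs.past_key_values
>>> # Later steps only need the newest decoder token; the cached states cover the rest
>>> outputs = model(
...     input_ids=input_ids,
...     decoder_input_ids=next_decoder_token,
...     past_key_values=past,
...     use_cache=True,
... )
```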
  1259. """
  1260. T5_ENCODER_INPUTS_DOCSTRING = r"""
  1261. Args:
  1262. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1263. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1264. should be able to pad the inputs on both the right and the left.
  1265. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1266. [`PreTrainedTokenizer.__call__`] for details.
1267. To learn more about how to prepare `input_ids` for pretraining, take a look at [T5 Training](./t5#training).
  1268. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1269. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1270. - 1 for tokens that are **not masked**,
  1271. - 0 for tokens that are **masked**.
  1272. [What are attention masks?](../glossary#attention-mask)
  1273. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1274. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  1275. - 1 indicates the head is **not masked**,
  1276. - 0 indicates the head is **masked**.
  1277. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1278. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1279. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1280. model's internal embedding lookup matrix.
  1281. output_attentions (`bool`, *optional*):
  1282. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  1283. tensors for more detail.
  1284. output_hidden_states (`bool`, *optional*):
  1285. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  1286. more detail.
  1287. return_dict (`bool`, *optional*):
  1288. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1289. """
  1290. # Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1291. __HEAD_MASK_WARNING_MSG = """
  1292. The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
  1293. `decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
  1294. If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
  1295. num_heads)`.
  1296. """
  1297. @add_start_docstrings(
  1298. "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.",
  1299. T5_START_DOCSTRING,
  1300. )
  1301. class T5Model(T5PreTrainedModel):
  1302. _keys_to_ignore_on_load_unexpected = [
  1303. "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
  1304. ]
  1305. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1306. def __init__(self, config: T5Config):
  1307. super().__init__(config)
  1308. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1309. encoder_config = copy.deepcopy(config)
  1310. encoder_config.is_decoder = False
  1311. encoder_config.use_cache = False
  1312. encoder_config.is_encoder_decoder = False
  1313. self.encoder = T5Stack(encoder_config, self.shared)
  1314. decoder_config = copy.deepcopy(config)
  1315. decoder_config.is_decoder = True
  1316. decoder_config.is_encoder_decoder = False
  1317. decoder_config.num_layers = config.num_decoder_layers
  1318. self.decoder = T5Stack(decoder_config, self.shared)
  1319. # Initialize weights and apply final processing
  1320. self.post_init()
  1321. # Model parallel
  1322. self.model_parallel = False
  1323. self.device_map = None
  1324. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  1325. def parallelize(self, device_map=None):
  1326. warnings.warn(
  1327. "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
  1328. " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
  1329. " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
  1330. " 0, 'encoder.block.1': 1, ...}",
  1331. FutureWarning,
  1332. )
  1333. self.device_map = (
  1334. get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
  1335. if device_map is None
  1336. else device_map
  1337. )
  1338. assert_device_map(self.device_map, len(self.encoder.block))
  1339. self.encoder.parallelize(self.device_map)
  1340. self.decoder.parallelize(self.device_map)
  1341. self.model_parallel = True
  1342. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  1343. def deparallelize(self):
  1344. warnings.warn(
  1345. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  1346. FutureWarning,
  1347. )
  1348. self.encoder.deparallelize()
  1349. self.decoder.deparallelize()
  1350. self.encoder = self.encoder.to("cpu")
  1351. self.decoder = self.decoder.to("cpu")
  1352. self.model_parallel = False
  1353. self.device_map = None
  1354. torch.cuda.empty_cache()
  1355. def get_input_embeddings(self):
  1356. return self.shared
  1357. def set_input_embeddings(self, new_embeddings):
  1358. self.shared = new_embeddings
  1359. self.encoder.set_input_embeddings(new_embeddings)
  1360. self.decoder.set_input_embeddings(new_embeddings)
  1361. def _tie_weights(self):
  1362. if self.config.tie_word_embeddings:
  1363. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1364. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  1365. def get_encoder(self):
  1366. return self.encoder
  1367. def get_decoder(self):
  1368. return self.decoder
  1369. def _prune_heads(self, heads_to_prune):
  1370. """
  1371. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  1372. class PreTrainedModel
  1373. """
  1374. for layer, heads in heads_to_prune.items():
  1375. self.encoder.layer[layer].attention.prune_heads(heads)
  1376. @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
  1377. @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
  1378. def forward(
  1379. self,
  1380. input_ids: Optional[torch.LongTensor] = None,
  1381. attention_mask: Optional[torch.FloatTensor] = None,
  1382. decoder_input_ids: Optional[torch.LongTensor] = None,
  1383. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1384. head_mask: Optional[torch.FloatTensor] = None,
  1385. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1386. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1387. encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  1388. past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  1389. inputs_embeds: Optional[torch.Tensor] = None,
  1390. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1391. use_cache: Optional[bool] = None,
  1392. output_attentions: Optional[bool] = None,
  1393. output_hidden_states: Optional[bool] = None,
  1394. return_dict: Optional[bool] = None,
  1395. cache_position: Optional[torch.LongTensor] = None,
  1396. ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
  1397. r"""
  1398. Returns:
  1399. Example:
  1400. ```python
  1401. >>> from transformers import AutoTokenizer, T5Model
  1402. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
  1403. >>> model = T5Model.from_pretrained("google-t5/t5-small")
  1404. >>> input_ids = tokenizer(
  1405. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1406. ... ).input_ids # Batch size 1
  1407. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1408. >>> # preprocess: Prepend decoder_input_ids with the start token, which is the pad token for T5Model.
1409. >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using the labels arg.
  1410. >>> decoder_input_ids = model._shift_right(decoder_input_ids)
  1411. >>> # forward pass
  1412. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1413. >>> last_hidden_states = outputs.last_hidden_state
  1414. ```"""
  1415. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1416. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1417. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1418. if head_mask is not None and decoder_head_mask is None:
  1419. if self.config.num_layers == self.config.num_decoder_layers:
  1420. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  1421. decoder_head_mask = head_mask
  1422. # Encode if needed (training, first prediction pass)
  1423. if encoder_outputs is None:
  1424. encoder_outputs = self.encoder(
  1425. input_ids=input_ids,
  1426. attention_mask=attention_mask,
  1427. inputs_embeds=inputs_embeds,
  1428. head_mask=head_mask,
  1429. output_attentions=output_attentions,
  1430. output_hidden_states=output_hidden_states,
  1431. return_dict=return_dict,
  1432. )
  1433. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1434. encoder_outputs = BaseModelOutput(
  1435. last_hidden_state=encoder_outputs[0],
  1436. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1437. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1438. )
  1439. hidden_states = encoder_outputs[0]
  1440. # Set device for model parallelism
  1441. if self.model_parallel:
  1442. torch.cuda.set_device(self.decoder.first_device)
  1443. hidden_states = hidden_states.to(self.decoder.first_device)
  1444. if decoder_input_ids is not None:
  1445. decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
  1446. if attention_mask is not None:
  1447. attention_mask = attention_mask.to(self.decoder.first_device)
  1448. if decoder_attention_mask is not None:
  1449. decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
  1450. # Decode
  1451. decoder_outputs = self.decoder(
  1452. input_ids=decoder_input_ids,
  1453. attention_mask=decoder_attention_mask,
  1454. inputs_embeds=decoder_inputs_embeds,
  1455. past_key_values=past_key_values,
  1456. encoder_hidden_states=hidden_states,
  1457. encoder_attention_mask=attention_mask,
  1458. head_mask=decoder_head_mask,
  1459. cross_attn_head_mask=cross_attn_head_mask,
  1460. use_cache=use_cache,
  1461. output_attentions=output_attentions,
  1462. output_hidden_states=output_hidden_states,
  1463. return_dict=return_dict,
  1464. cache_position=cache_position,
  1465. )
  1466. if not return_dict:
  1467. return decoder_outputs + encoder_outputs
  1468. return Seq2SeqModelOutput(
  1469. last_hidden_state=decoder_outputs.last_hidden_state,
  1470. past_key_values=decoder_outputs.past_key_values,
  1471. decoder_hidden_states=decoder_outputs.hidden_states,
  1472. decoder_attentions=decoder_outputs.attentions,
  1473. cross_attentions=decoder_outputs.cross_attentions,
  1474. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1475. encoder_hidden_states=encoder_outputs.hidden_states,
  1476. encoder_attentions=encoder_outputs.attentions,
  1477. )
  1478. @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING)
  1479. class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
  1480. _keys_to_ignore_on_load_unexpected = [
  1481. "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
  1482. ]
  1483. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
  1484. def __init__(self, config: T5Config):
  1485. super().__init__(config)
  1486. self.model_dim = config.d_model
  1487. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1488. encoder_config = copy.deepcopy(config)
  1489. encoder_config.is_decoder = False
  1490. encoder_config.use_cache = False
  1491. encoder_config.is_encoder_decoder = False
  1492. self.encoder = T5Stack(encoder_config, self.shared)
  1493. decoder_config = copy.deepcopy(config)
  1494. decoder_config.is_decoder = True
  1495. decoder_config.is_encoder_decoder = False
  1496. decoder_config.num_layers = config.num_decoder_layers
  1497. self.decoder = T5Stack(decoder_config, self.shared)
  1498. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  1499. # Initialize weights and apply final processing
  1500. self.post_init()
  1501. # Model parallel
  1502. self.model_parallel = False
  1503. self.device_map = None
  1504. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  1505. def parallelize(self, device_map=None):
  1506. warnings.warn(
  1507. "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
  1508. " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
  1509. " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
  1510. " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
  1511. FutureWarning,
  1512. )
  1513. self.device_map = (
  1514. get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
  1515. if device_map is None
  1516. else device_map
  1517. )
  1518. assert_device_map(self.device_map, len(self.encoder.block))
  1519. self.encoder.parallelize(self.device_map)
  1520. self.decoder.parallelize(self.device_map)
  1521. self.lm_head = self.lm_head.to(self.decoder.first_device)
  1522. self.model_parallel = True
  1523. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  1524. def deparallelize(self):
  1525. warnings.warn(
  1526. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  1527. FutureWarning,
  1528. )
  1529. self.encoder.deparallelize()
  1530. self.decoder.deparallelize()
  1531. self.encoder = self.encoder.to("cpu")
  1532. self.decoder = self.decoder.to("cpu")
  1533. self.lm_head = self.lm_head.to("cpu")
  1534. self.model_parallel = False
  1535. self.device_map = None
  1536. torch.cuda.empty_cache()
  1537. def get_input_embeddings(self):
  1538. return self.shared
  1539. def set_input_embeddings(self, new_embeddings):
  1540. self.shared = new_embeddings
  1541. self.encoder.set_input_embeddings(new_embeddings)
  1542. self.decoder.set_input_embeddings(new_embeddings)
  1543. def _tie_weights(self):
  1544. if self.config.tie_word_embeddings:
  1545. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1546. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  1547. def set_output_embeddings(self, new_embeddings):
  1548. self.lm_head = new_embeddings
  1549. def get_output_embeddings(self):
  1550. return self.lm_head
  1551. def get_encoder(self):
  1552. return self.encoder
  1553. def get_decoder(self):
  1554. return self.decoder
  1555. @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
  1556. @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
  1557. def forward(
  1558. self,
  1559. input_ids: Optional[torch.LongTensor] = None,
  1560. attention_mask: Optional[torch.FloatTensor] = None,
  1561. decoder_input_ids: Optional[torch.LongTensor] = None,
  1562. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1563. head_mask: Optional[torch.FloatTensor] = None,
  1564. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1565. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1566. encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  1567. past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  1568. inputs_embeds: Optional[torch.FloatTensor] = None,
  1569. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1570. labels: Optional[torch.LongTensor] = None,
  1571. use_cache: Optional[bool] = None,
  1572. output_attentions: Optional[bool] = None,
  1573. output_hidden_states: Optional[bool] = None,
  1574. return_dict: Optional[bool] = None,
  1575. cache_position: Optional[torch.LongTensor] = None,
  1576. ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
  1577. r"""
1578. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1579. Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
1580. config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
1581. labels in `[0, ..., config.vocab_size - 1]`.
  1582. Returns:
  1583. Examples:
  1584. ```python
  1585. >>> from transformers import AutoTokenizer, T5ForConditionalGeneration
  1586. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
  1587. >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
  1588. >>> # training
  1589. >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
  1590. >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
  1591. >>> outputs = model(input_ids=input_ids, labels=labels)
  1592. >>> loss = outputs.loss
  1593. >>> logits = outputs.logits
  1594. >>> # inference
  1595. >>> input_ids = tokenizer(
  1596. ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
  1597. ... ).input_ids # Batch size 1
  1598. >>> outputs = model.generate(input_ids)
  1599. >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  1600. >>> # studies have shown that owning a dog is good for you.
  1601. ```"""
  1602. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1603. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1604. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1605. if head_mask is not None and decoder_head_mask is None:
  1606. if self.config.num_layers == self.config.num_decoder_layers:
  1607. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  1608. decoder_head_mask = head_mask
  1609. # Encode if needed (training, first prediction pass)
  1610. if encoder_outputs is None:
  1611. # Convert encoder inputs in embeddings if needed
  1612. encoder_outputs = self.encoder(
  1613. input_ids=input_ids,
  1614. attention_mask=attention_mask,
  1615. inputs_embeds=inputs_embeds,
  1616. head_mask=head_mask,
  1617. output_attentions=output_attentions,
  1618. output_hidden_states=output_hidden_states,
  1619. return_dict=return_dict,
  1620. )
  1621. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1622. encoder_outputs = BaseModelOutput(
  1623. last_hidden_state=encoder_outputs[0],
  1624. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1625. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1626. )
  1627. hidden_states = encoder_outputs[0]
  1628. if self.model_parallel:
  1629. torch.cuda.set_device(self.decoder.first_device)
  1630. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  1631. # get decoder inputs from shifting lm labels to the right
  1632. decoder_input_ids = self._shift_right(labels)
  1633. # Set device for model parallelism
  1634. if self.model_parallel:
  1635. torch.cuda.set_device(self.decoder.first_device)
  1636. hidden_states = hidden_states.to(self.decoder.first_device)
  1637. if decoder_input_ids is not None:
  1638. decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
  1639. if attention_mask is not None:
  1640. attention_mask = attention_mask.to(self.decoder.first_device)
  1641. if decoder_attention_mask is not None:
  1642. decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
  1643. # Decode
  1644. decoder_outputs = self.decoder(
  1645. input_ids=decoder_input_ids,
  1646. attention_mask=decoder_attention_mask,
  1647. inputs_embeds=decoder_inputs_embeds,
  1648. past_key_values=past_key_values,
  1649. encoder_hidden_states=hidden_states,
  1650. encoder_attention_mask=attention_mask,
  1651. head_mask=decoder_head_mask,
  1652. cross_attn_head_mask=cross_attn_head_mask,
  1653. use_cache=use_cache,
  1654. output_attentions=output_attentions,
  1655. output_hidden_states=output_hidden_states,
  1656. return_dict=return_dict,
  1657. cache_position=cache_position,
  1658. )
  1659. sequence_output = decoder_outputs[0]
  1660. # Set device for model parallelism
  1661. if self.model_parallel:
  1662. torch.cuda.set_device(self.encoder.first_device)
  1663. self.lm_head = self.lm_head.to(self.encoder.first_device)
  1664. sequence_output = sequence_output.to(self.lm_head.weight.device)
  1665. if self.config.tie_word_embeddings:
  1666. # Rescale output before projecting on vocab
  1667. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
  1668. sequence_output = sequence_output * (self.model_dim**-0.5)
  1669. lm_logits = self.lm_head(sequence_output)
  1670. loss = None
  1671. if labels is not None:
  1672. loss_fct = CrossEntropyLoss(ignore_index=-100)
  1673. # move labels to correct device to enable PP
  1674. labels = labels.to(lm_logits.device)
  1675. loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  1676. # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
  1677. if not return_dict:
  1678. output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
  1679. return ((loss,) + output) if loss is not None else output
  1680. return Seq2SeqLMOutput(
  1681. loss=loss,
  1682. logits=lm_logits,
  1683. past_key_values=decoder_outputs.past_key_values,
  1684. decoder_hidden_states=decoder_outputs.hidden_states,
  1685. decoder_attentions=decoder_outputs.attentions,
  1686. cross_attentions=decoder_outputs.cross_attentions,
  1687. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1688. encoder_hidden_states=encoder_outputs.hidden_states,
  1689. encoder_attentions=encoder_outputs.attentions,
  1690. )
  1691. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1692. return self._shift_right(labels)
  1693. def _reorder_cache(self, past_key_values, beam_idx):
  1694. # if decoder past is not included in output
  1695. # speedy decoding is disabled and no need to reorder
  1696. if past_key_values is None:
  1697. logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
  1698. return past_key_values
  1699. reordered_decoder_past = ()
  1700. for layer_past_states in past_key_values:
1701. # get the correct batch idx from the layer past batch dim
1702. # the batch (beam) dim of each `past` tensor is at position 0, so we index_select along dim 0
  1703. reordered_layer_past_states = ()
  1704. for layer_past_state in layer_past_states:
  1705. # need to set correct `past` for each of the four key / value states
  1706. reordered_layer_past_states = reordered_layer_past_states + (
  1707. layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
  1708. )
  1709. if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
  1710. raise ValueError(
  1711. f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
  1712. )
  1713. if len(reordered_layer_past_states) != len(layer_past_states):
  1714. raise ValueError(
  1715. f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
  1716. )
  1717. reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
  1718. return reordered_decoder_past
  1719. @add_start_docstrings(
  1720. "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.",
  1721. T5_START_DOCSTRING,
  1722. )
  1723. class T5EncoderModel(T5PreTrainedModel):
  1724. _tied_weights_keys = ["encoder.embed_tokens.weight"]
  1725. _keys_to_ignore_on_load_unexpected = [r"decoder"]
  1726. def __init__(self, config: T5Config):
  1727. super().__init__(config)
  1728. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1729. encoder_config = copy.deepcopy(config)
  1730. encoder_config.use_cache = False
  1731. encoder_config.is_encoder_decoder = False
  1732. self.encoder = T5Stack(encoder_config, self.shared)
  1733. # Initialize weights and apply final processing
  1734. self.post_init()
  1735. # Model parallel
  1736. self.model_parallel = False
  1737. self.device_map = None
  1738. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  1739. def parallelize(self, device_map=None):
  1740. warnings.warn(
  1741. "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
  1742. " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
  1743. " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
  1744. " 'block.1': 1, ...}",
  1745. FutureWarning,
  1746. )
  1747. self.device_map = (
  1748. get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
  1749. if device_map is None
  1750. else device_map
  1751. )
  1752. assert_device_map(self.device_map, len(self.encoder.block))
  1753. self.encoder.parallelize(self.device_map)
  1754. self.model_parallel = True
  1755. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  1756. def deparallelize(self):
  1757. warnings.warn(
  1758. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  1759. FutureWarning,
  1760. )
  1761. self.encoder.deparallelize()
  1762. self.encoder = self.encoder.to("cpu")
  1763. self.model_parallel = False
  1764. self.device_map = None
  1765. torch.cuda.empty_cache()
  1766. def get_input_embeddings(self):
  1767. return self.shared
  1768. def set_input_embeddings(self, new_embeddings):
  1769. self.shared = new_embeddings
  1770. self.encoder.set_input_embeddings(new_embeddings)
  1771. def _tie_weights(self):
  1772. if self.config.tie_word_embeddings:
  1773. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1774. def get_encoder(self):
  1775. return self.encoder
  1776. def _prune_heads(self, heads_to_prune):
  1777. """
  1778. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  1779. class PreTrainedModel
  1780. """
  1781. for layer, heads in heads_to_prune.items():
  1782. self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
  1783. @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING)
  1784. @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
  1785. def forward(
  1786. self,
  1787. input_ids: Optional[torch.LongTensor] = None,
  1788. attention_mask: Optional[torch.FloatTensor] = None,
  1789. head_mask: Optional[torch.FloatTensor] = None,
  1790. inputs_embeds: Optional[torch.FloatTensor] = None,
  1791. output_attentions: Optional[bool] = None,
  1792. output_hidden_states: Optional[bool] = None,
  1793. return_dict: Optional[bool] = None,
  1794. ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
  1795. r"""
  1796. Returns:
  1797. Example:
  1798. ```python
  1799. >>> from transformers import AutoTokenizer, T5EncoderModel
  1800. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
  1801. >>> model = T5EncoderModel.from_pretrained("google-t5/t5-small")
  1802. >>> input_ids = tokenizer(
  1803. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1804. ... ).input_ids # Batch size 1
  1805. >>> outputs = model(input_ids=input_ids)
  1806. >>> last_hidden_states = outputs.last_hidden_state
  1807. ```"""
  1808. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1809. encoder_outputs = self.encoder(
  1810. input_ids=input_ids,
  1811. attention_mask=attention_mask,
  1812. inputs_embeds=inputs_embeds,
  1813. head_mask=head_mask,
  1814. output_attentions=output_attentions,
  1815. output_hidden_states=output_hidden_states,
  1816. return_dict=return_dict,
  1817. )
  1818. return encoder_outputs
  1819. @add_start_docstrings(
  1820. """
1821. T5 model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for GLUE
  1822. tasks.
  1823. """,
  1824. T5_START_DOCSTRING,
  1825. )
  1826. class T5ForSequenceClassification(T5PreTrainedModel):
  1827. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1828. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1829. def __init__(self, config: T5Config):
  1830. super().__init__(config)
  1831. self.transformer = T5Model(config)
  1832. self.classification_head = T5ClassificationHead(config)
  1833. # Initialize weights and apply final processing
  1834. self.post_init()
  1835. self.model_parallel = False
  1836. @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
  1837. @replace_return_docstrings(output_type=Seq2SeqSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
  1838. def forward(
  1839. self,
  1840. input_ids: torch.LongTensor = None,
  1841. attention_mask: Optional[torch.Tensor] = None,
  1842. decoder_input_ids: Optional[torch.LongTensor] = None,
  1843. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1844. head_mask: Optional[torch.Tensor] = None,
  1845. decoder_head_mask: Optional[torch.Tensor] = None,
  1846. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1847. encoder_outputs: Optional[List[torch.FloatTensor]] = None,
  1848. inputs_embeds: Optional[torch.FloatTensor] = None,
  1849. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1850. labels: Optional[torch.LongTensor] = None,
  1851. use_cache: Optional[bool] = None,
  1852. output_attentions: Optional[bool] = None,
  1853. output_hidden_states: Optional[bool] = None,
  1854. return_dict: Optional[bool] = None,
  1855. ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]:
  1856. r"""
  1857. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1858. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1859. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1860. Returns:
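Example:
An illustrative sketch (the classification head of `google-t5/t5-small` is randomly initialized here, so the
logits are untrained):
```python
>>> from transformers import AutoTokenizer, T5ForSequenceClassification

>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> model = T5ForSequenceClassification.from_pretrained("google-t5/t5-small", num_labels=2)
>>> inputs = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt")
>>> logits = model(**inputs).logits  # shape (batch_size, num_labels)
```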
  1861. """
  1862. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1863. if labels is not None:
  1864. use_cache = False
  1865. if input_ids is None and inputs_embeds is not None:
  1866. raise NotImplementedError(
  1867. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1868. )
1869. # Copied from models.bart.modeling_bart.BartModel.forward. Different from other models, T5 automatically creates
  1870. # decoder_input_ids from input_ids if no decoder_input_ids are provided
  1871. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1872. if input_ids is None:
  1873. raise ValueError(
  1874. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1875. "passed, `input_ids` cannot be `None`. Please pass either "
  1876. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1877. )
  1878. decoder_input_ids = self._shift_right(input_ids)
  1879. outputs = self.transformer(
  1880. input_ids,
  1881. attention_mask=attention_mask,
  1882. decoder_input_ids=decoder_input_ids,
  1883. decoder_attention_mask=decoder_attention_mask,
  1884. head_mask=head_mask,
  1885. decoder_head_mask=decoder_head_mask,
  1886. cross_attn_head_mask=cross_attn_head_mask,
  1887. encoder_outputs=encoder_outputs,
  1888. inputs_embeds=inputs_embeds,
  1889. decoder_inputs_embeds=decoder_inputs_embeds,
  1890. use_cache=use_cache,
  1891. output_attentions=output_attentions,
  1892. output_hidden_states=output_hidden_states,
  1893. return_dict=return_dict,
  1894. )
  1895. sequence_output = outputs[0]
  1896. eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
  1897. if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
  1898. raise ValueError("All examples must have the same number of <eos> tokens.")
  1899. batch_size, _, hidden_size = sequence_output.shape
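# Pool by taking the decoder hidden state at the last <eos> token of each sequence as the sentence
# representation (the same pooling strategy used by BART-style sequence classification heads).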
  1900. sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
  1901. logits = self.classification_head(sentence_representation)
  1902. loss = None
  1903. if labels is not None:
  1904. labels = labels.to(logits.device)
  1905. if self.config.problem_type is None:
  1906. if self.config.num_labels == 1:
  1907. self.config.problem_type = "regression"
  1908. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1909. self.config.problem_type = "single_label_classification"
  1910. else:
  1911. self.config.problem_type = "multi_label_classification"
  1912. if self.config.problem_type == "regression":
  1913. loss_fct = MSELoss()
  1914. if self.config.num_labels == 1:
  1915. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1916. else:
  1917. loss = loss_fct(logits, labels)
  1918. elif self.config.problem_type == "single_label_classification":
  1919. loss_fct = CrossEntropyLoss()
  1920. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1921. elif self.config.problem_type == "multi_label_classification":
  1922. loss_fct = BCEWithLogitsLoss()
  1923. loss = loss_fct(logits, labels)
  1924. if not return_dict:
  1925. output = (logits,) + outputs[1:]
  1926. return ((loss,) + output) if loss is not None else output
  1927. return Seq2SeqSequenceClassifierOutput(
  1928. loss=loss,
  1929. logits=logits,
  1930. past_key_values=outputs.past_key_values,
  1931. decoder_hidden_states=outputs.decoder_hidden_states,
  1932. decoder_attentions=outputs.decoder_attentions,
  1933. cross_attentions=outputs.cross_attentions,
  1934. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1935. encoder_hidden_states=outputs.encoder_hidden_states,
  1936. encoder_attentions=outputs.encoder_attentions,
  1937. )
  1938. @add_start_docstrings(
  1939. """
  1940. T5 Encoder Model with a token classification head on top (a linear layer on top of the hidden-states output)
  1941. e.g. for Named-Entity-Recognition (NER) tasks.
  1942. """,
  1943. T5_START_DOCSTRING,
  1944. )
  1945. class T5ForTokenClassification(T5PreTrainedModel):
  1946. _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]
  1947. def __init__(self, config: T5Config):
  1948. super().__init__(config)
  1949. self.num_labels = config.num_labels
  1950. self.transformer = T5EncoderModel(config)
  1951. self.dropout = nn.Dropout(config.classifier_dropout)
  1952. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1953. # Initialize weights and apply final processing
  1954. self.post_init()
  1955. @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
  1956. @replace_return_docstrings(output_type=TokenClassifierOutput, config_class=_CONFIG_FOR_DOC)
  1957. def forward(
  1958. self,
  1959. input_ids: Optional[torch.Tensor] = None,
  1960. attention_mask: Optional[torch.Tensor] = None,
  1961. head_mask: Optional[torch.Tensor] = None,
  1962. inputs_embeds: Optional[torch.Tensor] = None,
  1963. labels: Optional[torch.Tensor] = None,
  1964. output_attentions: Optional[bool] = None,
  1965. output_hidden_states: Optional[bool] = None,
  1966. return_dict: Optional[bool] = None,
  1967. ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
  1968. r"""
  1969. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1970. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1971. Returns:
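Example:
An illustrative sketch (the token classification head is randomly initialized for `google-t5/t5-small`, so the
predictions are untrained):
```python
>>> from transformers import AutoTokenizer, T5ForTokenClassification

>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> model = T5ForTokenClassification.from_pretrained("google-t5/t5-small", num_labels=3)
>>> inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
>>> logits = model(**inputs).logits  # shape (batch_size, sequence_length, num_labels)
```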
  1972. """
  1973. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1974. outputs = self.transformer(
  1975. input_ids,
  1976. attention_mask=attention_mask,
  1977. head_mask=head_mask,
  1978. inputs_embeds=inputs_embeds,
  1979. output_attentions=output_attentions,
  1980. output_hidden_states=output_hidden_states,
  1981. return_dict=return_dict,
  1982. )
  1983. hidden_states = outputs[0]
  1984. hidden_states = self.dropout(hidden_states)
  1985. logits = self.classifier(hidden_states)
  1986. loss = None
  1987. if labels is not None:
  1988. loss_fct = CrossEntropyLoss()
  1989. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1990. if not return_dict:
1991. output = (logits,) + outputs[2:-1]
  1992. return ((loss,) + output) if loss is not None else output
  1993. return TokenClassifierOutput(
  1994. loss=loss,
  1995. logits=logits,
  1996. hidden_states=outputs.hidden_states,
  1997. attentions=outputs.attentions,
  1998. )
  1999. @add_start_docstrings(
  2000. """
  2001. T5 Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear layers
  2002. on top of the hidden-states output to compute `span start logits` and `span end logits`).
  2003. """,
  2004. T5_START_DOCSTRING,
  2005. )
  2006. class T5ForQuestionAnswering(T5PreTrainedModel):
  2007. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  2008. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  2009. def __init__(self, config: T5Config):
  2010. super().__init__(config)
  2011. self.model_dim = config.d_model
  2012. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  2013. encoder_config = copy.deepcopy(config)
  2014. encoder_config.is_decoder = False
  2015. encoder_config.use_cache = False
  2016. encoder_config.is_encoder_decoder = False
  2017. self.encoder = T5Stack(encoder_config, self.shared)
  2018. decoder_config = copy.deepcopy(config)
  2019. decoder_config.is_decoder = True
  2020. decoder_config.is_encoder_decoder = False
  2021. decoder_config.num_layers = config.num_decoder_layers
  2022. self.decoder = T5Stack(decoder_config, self.shared)
  2023. self.num_labels = config.num_labels
  2024. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  2025. # Initialize weights and apply final processing
  2026. self.post_init()
  2027. self.model_parallel = False
  2028. def get_input_embeddings(self):
  2029. return self.shared
  2030. def set_input_embeddings(self, new_embeddings):
  2031. self.shared = new_embeddings
  2032. self.encoder.set_input_embeddings(new_embeddings)
  2033. self.decoder.set_input_embeddings(new_embeddings)
  2034. def _tie_weights(self):
  2035. if self.config.tie_word_embeddings:
  2036. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  2037. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  2038. def get_encoder(self):
  2039. return self.encoder
  2040. def get_decoder(self):
  2041. return self.decoder
  2042. @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
  2043. @replace_return_docstrings(output_type=Seq2SeqQuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
  2044. def forward(
  2045. self,
  2046. input_ids: Optional[torch.LongTensor] = None,
  2047. attention_mask: Optional[torch.FloatTensor] = None,
  2048. decoder_input_ids: Optional[torch.LongTensor] = None,
  2049. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  2050. head_mask: Optional[torch.FloatTensor] = None,
  2051. decoder_head_mask: Optional[torch.FloatTensor] = None,
  2052. cross_attn_head_mask: Optional[torch.Tensor] = None,
  2053. encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  2054. start_positions: Optional[torch.LongTensor] = None,
  2055. end_positions: Optional[torch.LongTensor] = None,
  2056. inputs_embeds: Optional[torch.FloatTensor] = None,
  2057. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  2058. use_cache: Optional[bool] = None,
  2059. output_attentions: Optional[bool] = None,
  2060. output_hidden_states: Optional[bool] = None,
  2061. return_dict: Optional[bool] = None,
  2062. ) -> Union[Tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
  2063. r"""
  2064. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  2065. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  2066. Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
  2067. are not taken into account for computing the loss.
  2068. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  2069. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  2070. Positions are clamped to the length of the sequence (*sequence_length*). Position outside of the sequence
  2071. are not taken into account for computing the loss.
  2072. Returns:
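
        Example (a minimal usage sketch; the `t5-small` checkpoint below is only illustrative, and its QA head is
        randomly initialized, so predictions are not meaningful until the model is fine-tuned on a QA dataset):

        ```python
        >>> from transformers import AutoTokenizer, T5ForQuestionAnswering

        >>> tokenizer = AutoTokenizer.from_pretrained("t5-small")
        >>> model = T5ForQuestionAnswering.from_pretrained("t5-small")

        >>> question, context = "Who developed T5?", "T5 was developed by researchers at Google."
        >>> inputs = tokenizer(question, context, return_tensors="pt")

        >>> # decoder_input_ids are created automatically from input_ids (see the forward pass below)
        >>> outputs = model(**inputs)
        >>> start_index = outputs.start_logits.argmax(-1).item()
        >>> end_index = outputs.end_logits.argmax(-1).item()
        >>> answer = tokenizer.decode(inputs.input_ids[0, start_index : end_index + 1])
        ```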
  2073. """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if start_positions is not None and end_positions is not None:
            use_cache = False

        # Copied from models.bart.modeling_bart.BartModel.forward
        #   Different from other models, T5 automatically creates decoder_input_ids from
        #   input_ids if no decoder_input_ids are provided
        if decoder_input_ids is None and decoder_inputs_embeds is None:
            if input_ids is None:
                raise ValueError(
                    "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
                    "passed, `input_ids` cannot be `None`. Please pass either "
                    "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
                )
            decoder_input_ids = self._shift_right(input_ids)

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
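        # If the caller passed `encoder_outputs` as a plain tuple, wrap it in a `BaseModelOutput`
        # so the attribute accesses further down (`.last_hidden_state`, etc.) keep working.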
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=None,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]
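
        # Project each decoder hidden state to `num_labels` scores (2 for span QA) and split them
        # into per-token start and end logits of shape (batch_size, sequence_length).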
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the gather in DataParallel can add an extra dimension; remove it
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1).to(start_logits.device)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1).to(end_logits.device)
            # Positions outside the model inputs are clamped to `ignored_index` and excluded from the loss
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
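
        # With `return_dict=False`, fall back to a plain tuple (prepending the loss when labels were given).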
        if not return_dict:
            output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
            return ((total_loss,) + output) if total_loss is not None else output

        return Seq2SeqQuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )