modeling_funnel.py

# coding=utf-8
# Copyright 2020-present Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Funnel Transformer model."""

import os
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_funnel import FunnelConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "FunnelConfig"
_CHECKPOINT_FOR_DOC = "funnel-transformer/small"

INF = 1e6
def load_tf_weights_in_funnel(model, config, tf_checkpoint_path):
    """Load tf checkpoints in a pytorch model."""
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info(f"Loading TF weight {name} with shape {shape}")
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    _layer_map = {
        "k": "k_head",
        "q": "q_head",
        "v": "v_head",
        "o": "post_proj",
        "layer_1": "linear_1",
        "layer_2": "linear_2",
        "rel_attn": "attention",
        "ff": "ffn",
        "kernel": "weight",
        "gamma": "weight",
        "beta": "bias",
        "lookup_table": "weight",
        "word_embedding": "word_embeddings",
        "input": "embeddings",
    }

    for name, array in zip(names, arrays):
        name = name.split("/")
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
        # which are not required for using the pretrained model
        if any(
            n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
            for n in name
        ):
            logger.info(f"Skipping {'/'.join(name)}")
            continue
        if name[0] == "generator":
            continue
        pointer = model
        skipped = False
        for m_name in name[1:]:
            if not isinstance(pointer, FunnelPositionwiseFFN) and re.fullmatch(r"layer_\d+", m_name):
                layer_index = int(re.search(r"layer_(\d+)", m_name).groups()[0])
                if layer_index < config.num_hidden_layers:
                    block_idx = 0
                    while layer_index >= config.block_sizes[block_idx]:
                        layer_index -= config.block_sizes[block_idx]
                        block_idx += 1
                    pointer = pointer.blocks[block_idx][layer_index]
                else:
                    layer_index -= config.num_hidden_layers
                    pointer = pointer.layers[layer_index]
            elif m_name == "r" and isinstance(pointer, FunnelRelMultiheadAttention):
                pointer = pointer.r_kernel
                break
            elif m_name in _layer_map:
                pointer = getattr(pointer, _layer_map[m_name])
            else:
                try:
                    pointer = getattr(pointer, m_name)
                except AttributeError:
                    print(f"Skipping {'/'.join(name)}", array.shape)
                    skipped = True
                    break
        if not skipped:
            if len(pointer.shape) != len(array.shape):
                array = array.reshape(pointer.shape)
            if m_name == "kernel":
                array = np.transpose(array)
            pointer.data = torch.from_numpy(array)
    return model
class FunnelEmbeddings(nn.Module):
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(
        self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        embeddings = self.layer_norm(inputs_embeds)
        embeddings = self.dropout(embeddings)
        return embeddings
class FunnelAttentionStructure(nn.Module):
    """
    Contains helpers for `FunnelRelMultiheadAttention`.
    """

    cls_token_type_id: int = 2

    def __init__(self, config: FunnelConfig) -> None:
        super().__init__()
        self.config = config
        self.sin_dropout = nn.Dropout(config.hidden_dropout)
        self.cos_dropout = nn.Dropout(config.hidden_dropout)
        # Track where we are at in terms of pooling from the original input, e.g., by how much the sequence length was
        # divided.
        self.pooling_mult = None

    def init_attention_inputs(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor]:
        """Returns the attention inputs associated to the inputs of the model."""
        # inputs_embeds has shape batch_size x seq_len x d_model
        # attention_mask and token_type_ids have shape batch_size x seq_len
        self.pooling_mult = 1
        self.seq_len = seq_len = inputs_embeds.size(1)
        position_embeds = self.get_position_embeds(seq_len, inputs_embeds.dtype, inputs_embeds.device)
        token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
        cls_mask = (
            nn.functional.pad(inputs_embeds.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0))
            if self.config.separate_cls
            else None
        )
        return (position_embeds, token_type_mat, attention_mask, cls_mask)

    def token_type_ids_to_mat(self, token_type_ids: torch.Tensor) -> torch.Tensor:
        """Convert `token_type_ids` to `token_type_mat`."""
        token_type_mat = token_type_ids[:, :, None] == token_type_ids[:, None]
        # Treat <cls> as in the same segment as both A & B
        cls_ids = token_type_ids == self.cls_token_type_id
        cls_mat = cls_ids[:, :, None] | cls_ids[:, None]
        return cls_mat | token_type_mat
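
    # A minimal sketch (editor's illustration, not part of the upstream module), assuming the
    # <cls> token uses type id 2: a single sequence with token_type_ids = [2, 0, 0, 1, 1] yields a
    # token_type_mat whose <cls> row and column are all True and whose other entries are True only
    # within the same segment:
    #
    # >>> structure = FunnelAttentionStructure(FunnelConfig())
    # >>> structure.token_type_ids_to_mat(torch.tensor([[2, 0, 0, 1, 1]]))[0]
    # tensor([[ True,  True,  True,  True,  True],
    #         [ True,  True,  True, False, False],
    #         [ True,  True,  True, False, False],
    #         [ True, False, False,  True,  True],
    #         [ True, False, False,  True,  True]])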
    def get_position_embeds(
        self, seq_len: int, dtype: torch.dtype, device: torch.device
    ) -> Union[Tuple[torch.Tensor], List[List[torch.Tensor]]]:
        """
        Create and cache inputs related to relative position encoding. Those are very different depending on whether we
        are using the factorized or the relative shift attention:

        For the factorized attention, it returns the matrices (phi, pi, psi, omega) used in the paper, appendix A.2.2,
        final formula.

        For the relative shift attention, it returns all possible vectors R used in the paper, appendix A.2.1, final
        formula.

        Paper link: https://arxiv.org/abs/2006.03236
        """
        d_model = self.config.d_model
        if self.config.attention_type == "factorized":
            # Notations from the paper, appendix A.2.2, final formula.
            # We need to create and return the matrices phi, psi, pi and omega.
            pos_seq = torch.arange(0, seq_len, 1.0, dtype=torch.int64, device=device).to(dtype)
            freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
            inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
            sinusoid = pos_seq[:, None] * inv_freq[None]
            sin_embed = torch.sin(sinusoid)
            sin_embed_d = self.sin_dropout(sin_embed)
            cos_embed = torch.cos(sinusoid)
            cos_embed_d = self.cos_dropout(cos_embed)
            # This is different from the formula on the paper...
            phi = torch.cat([sin_embed_d, sin_embed_d], dim=-1)
            psi = torch.cat([cos_embed, sin_embed], dim=-1)
            pi = torch.cat([cos_embed_d, cos_embed_d], dim=-1)
            omega = torch.cat([-sin_embed, cos_embed], dim=-1)
            return (phi, pi, psi, omega)
        else:
            # Notations from the paper, appendix A.2.1, final formula.
            # We need to create and return all the possible vectors R for all blocks and shifts.
            freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
            inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
            # Maximum relative positions for the first input
            rel_pos_id = torch.arange(-seq_len * 2, seq_len * 2, 1.0, dtype=torch.int64, device=device).to(dtype)
            zero_offset = seq_len * 2
            sinusoid = rel_pos_id[:, None] * inv_freq[None]
            sin_embed = self.sin_dropout(torch.sin(sinusoid))
            cos_embed = self.cos_dropout(torch.cos(sinusoid))
            pos_embed = torch.cat([sin_embed, cos_embed], dim=-1)

            pos = torch.arange(0, seq_len, dtype=torch.int64, device=device).to(dtype)
            pooled_pos = pos
            position_embeds_list = []
            for block_index in range(0, self.config.num_blocks):
                # For each block with block_index > 0, we need two types of position embeddings:
                #   - Attention(pooled-q, unpooled-kv)
                #   - Attention(pooled-q, pooled-kv)
                # For block_index = 0 we only need the second one and leave the first one as None.

                # First type
                if block_index == 0:
                    position_embeds_pooling = None
                else:
                    pooled_pos = self.stride_pool_pos(pos, block_index)

                    # construct rel_pos_id
                    stride = 2 ** (block_index - 1)
                    rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
                    rel_pos = rel_pos[:, None] + zero_offset
                    rel_pos = rel_pos.expand(rel_pos.size(0), d_model)
                    position_embeds_pooling = torch.gather(pos_embed, 0, rel_pos)

                # Second type
                pos = pooled_pos
                stride = 2**block_index
                rel_pos = self.relative_pos(pos, stride)

                rel_pos = rel_pos[:, None] + zero_offset
                rel_pos = rel_pos.expand(rel_pos.size(0), d_model)
                position_embeds_no_pooling = torch.gather(pos_embed, 0, rel_pos)

                position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling])
            return position_embeds_list
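
    # Editor's note (illustrative, not upstream code): with the default relative-shift attention,
    # each entry of `position_embeds_list` is a pair [position_embeds_no_pooling,
    # position_embeds_pooling] of tensors of shape (num_relative_positions, d_model); in the
    # factorized branch above, phi, pi, psi and omega each have shape (seq_len, d_model).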
    def stride_pool_pos(self, pos_id: torch.Tensor, block_index: int):
        """
        Pool `pos_id` while keeping the cls token separate (if `config.separate_cls=True`).
        """
        if self.config.separate_cls:
            # Under separate <cls>, we treat the <cls> as the first token in
            # the previous block of the 1st real block. Since the 1st real
            # block always has position 1, the position of the previous block
            # will be at `1 - 2 ** block_index`.
            cls_pos = pos_id.new_tensor([-(2**block_index) + 1])
            pooled_pos_id = pos_id[1:-1] if self.config.truncate_seq else pos_id[1:]
            return torch.cat([cls_pos, pooled_pos_id[::2]], 0)
        else:
            return pos_id[::2]
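
    # A small worked example (editor's sketch, not upstream code), assuming separate_cls=True and
    # truncate_seq=False: for block_index=1 and pos_id = [0, 1, 2, 3, 4, 5, 6, 7], the <cls>
    # position becomes 1 - 2**1 = -1 and the remaining positions are strided by 2, giving
    # [-1, 1, 3, 5, 7].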
    def relative_pos(self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1) -> torch.Tensor:
        """
        Build the relative positional vector between `pos` and `pooled_pos`.
        """
        if pooled_pos is None:
            pooled_pos = pos

        ref_point = pooled_pos[0] - pos[0]
        num_remove = shift * len(pooled_pos)
        max_dist = ref_point + num_remove * stride
        min_dist = pooled_pos[0] - pos[-1]

        return torch.arange(max_dist, min_dist - 1, -stride, dtype=torch.long, device=pos.device)
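
    # A small worked example (editor's sketch, not upstream code): with pos = [0, 1, ..., 7],
    # stride=1, pooled_pos=None and shift=1, we get max_dist = 8 and min_dist = -7, so the method
    # returns the descending vector [8, 7, ..., -7], i.e. all 2 * seq_len possible relative
    # offsets for a length-8 sequence.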
    def stride_pool(
        self,
        tensor: Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]],
        axis: Union[int, Tuple[int], List[int]],
    ) -> torch.Tensor:
        """
        Perform pooling by stride slicing the tensor along the given axis.
        """
        if tensor is None:
            return None

        # Do the stride pool recursively if axis is a list or a tuple of ints.
        if isinstance(axis, (list, tuple)):
            for ax in axis:
                tensor = self.stride_pool(tensor, ax)
            return tensor

        # Do the stride pool recursively if tensor is a list or tuple of tensors.
        if isinstance(tensor, (tuple, list)):
            return type(tensor)(self.stride_pool(x, axis) for x in tensor)

        # Deal with negative axis
        axis %= tensor.ndim

        axis_slice = (
            slice(None, -1, 2) if self.config.separate_cls and self.config.truncate_seq else slice(None, None, 2)
        )
        enc_slice = [slice(None)] * axis + [axis_slice]
        if self.config.separate_cls:
            cls_slice = [slice(None)] * axis + [slice(None, 1)]
            tensor = torch.cat([tensor[cls_slice], tensor], axis=axis)
        return tensor[enc_slice]
    def pool_tensor(
        self, tensor: Union[torch.Tensor, Tuple[torch.Tensor], List[torch.Tensor]], mode: str = "mean", stride: int = 2
    ) -> torch.Tensor:
        """Apply 1D pooling to a tensor of size [B x T (x H)]."""
        if tensor is None:
            return None

        # Do the pool recursively if tensor is a list or tuple of tensors (pool each element, not
        # the whole container, otherwise the recursion never terminates).
        if isinstance(tensor, (tuple, list)):
            return type(tensor)(self.pool_tensor(x, mode=mode, stride=stride) for x in tensor)

        if self.config.separate_cls:
            suffix = tensor[:, :-1] if self.config.truncate_seq else tensor
            tensor = torch.cat([tensor[:, :1], suffix], dim=1)

        ndim = tensor.ndim
        if ndim == 2:
            tensor = tensor[:, None, :, None]
        elif ndim == 3:
            tensor = tensor[:, None, :, :]

        # Stride is applied on the second-to-last dimension.
        stride = (stride, 1)

        if mode == "mean":
            tensor = nn.functional.avg_pool2d(tensor, stride, stride=stride, ceil_mode=True)
        elif mode == "max":
            tensor = nn.functional.max_pool2d(tensor, stride, stride=stride, ceil_mode=True)
        elif mode == "min":
            tensor = -nn.functional.max_pool2d(-tensor, stride, stride=stride, ceil_mode=True)
        else:
            raise NotImplementedError("The supported modes are 'mean', 'max' and 'min'.")

        if ndim == 2:
            return tensor[:, 0, :, 0]
        elif ndim == 3:
            return tensor[:, 0]
        return tensor
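
    # A small worked example (editor's sketch, not upstream code), assuming separate_cls=True,
    # truncate_seq=False and mode="mean": a 1 x 5 row [c, a, b, d, e] is first turned into
    # [c, c, a, b, d, e] (the <cls> entry is duplicated), then average-pooled with stride 2,
    # giving [c, (a + b) / 2, (d + e) / 2] -- so the <cls> position survives pooling unchanged.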
    def pre_attention_pooling(
        self, output, attention_inputs: Tuple[torch.Tensor]
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        """Pool `output` and the proper parts of `attention_inputs` before the attention layer."""
        position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
        if self.config.pool_q_only:
            if self.config.attention_type == "factorized":
                position_embeds = self.stride_pool(position_embeds[:2], 0) + position_embeds[2:]
            token_type_mat = self.stride_pool(token_type_mat, 1)
            cls_mask = self.stride_pool(cls_mask, 0)
            output = self.pool_tensor(output, mode=self.config.pooling_type)
        else:
            self.pooling_mult *= 2
            if self.config.attention_type == "factorized":
                position_embeds = self.stride_pool(position_embeds, 0)
            token_type_mat = self.stride_pool(token_type_mat, [1, 2])
            cls_mask = self.stride_pool(cls_mask, [1, 2])
            attention_mask = self.pool_tensor(attention_mask, mode="min")
            output = self.pool_tensor(output, mode=self.config.pooling_type)
        attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
        return output, attention_inputs

    def post_attention_pooling(self, attention_inputs: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
        """Pool the proper parts of `attention_inputs` after the attention layer."""
        position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
        if self.config.pool_q_only:
            self.pooling_mult *= 2
            if self.config.attention_type == "factorized":
                position_embeds = position_embeds[:2] + self.stride_pool(position_embeds[2:], 0)
            token_type_mat = self.stride_pool(token_type_mat, 2)
            cls_mask = self.stride_pool(cls_mask, 1)
            attention_mask = self.pool_tensor(attention_mask, mode="min")
        attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
        return attention_inputs
def _relative_shift_gather(positional_attn: torch.Tensor, context_len: int, shift: int) -> torch.Tensor:
    batch_size, n_head, seq_len, max_rel_len = positional_attn.shape
    # max_rel_len = 2 * context_len + shift - 1 is the number of possible relative positions i-j

    # What's next is the same as doing the following gather, which might be clearer code but less efficient.
    # idxs = context_len + torch.arange(0, context_len).unsqueeze(0) - torch.arange(0, seq_len).unsqueeze(1)
    # # matrix of context_len + i-j
    # return positional_attn.gather(3, idxs.expand([batch_size, n_head, context_len, context_len]))

    positional_attn = torch.reshape(positional_attn, [batch_size, n_head, max_rel_len, seq_len])
    positional_attn = positional_attn[:, :, shift:, :]
    positional_attn = torch.reshape(positional_attn, [batch_size, n_head, seq_len, max_rel_len - shift])
    positional_attn = positional_attn[..., :context_len]
    return positional_attn
class FunnelRelMultiheadAttention(nn.Module):
    def __init__(self, config: FunnelConfig, block_index: int) -> None:
        super().__init__()
        self.config = config
        self.block_index = block_index
        d_model, n_head, d_head = config.d_model, config.n_head, config.d_head

        self.hidden_dropout = nn.Dropout(config.hidden_dropout)
        self.attention_dropout = nn.Dropout(config.attention_dropout)

        self.q_head = nn.Linear(d_model, n_head * d_head, bias=False)
        self.k_head = nn.Linear(d_model, n_head * d_head)
        self.v_head = nn.Linear(d_model, n_head * d_head)

        self.r_w_bias = nn.Parameter(torch.zeros([n_head, d_head]))
        self.r_r_bias = nn.Parameter(torch.zeros([n_head, d_head]))
        self.r_kernel = nn.Parameter(torch.zeros([d_model, n_head, d_head]))
        self.r_s_bias = nn.Parameter(torch.zeros([n_head, d_head]))
        self.seg_embed = nn.Parameter(torch.zeros([2, n_head, d_head]))

        self.post_proj = nn.Linear(n_head * d_head, d_model)
        self.layer_norm = nn.LayerNorm(d_model, eps=config.layer_norm_eps)
        self.scale = 1.0 / (d_head**0.5)

    def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
        """Relative attention score for the positional encodings"""
        # q_head has shape batch_size x seq_len x n_head x d_head
        if self.config.attention_type == "factorized":
            # Notations from the paper, appendix A.2.2, final formula (https://arxiv.org/abs/2006.03236)
            # phi and pi have shape seq_len x d_model, psi and omega have shape context_len x d_model
            phi, pi, psi, omega = position_embeds
            # Shape n_head x d_head
            u = self.r_r_bias * self.scale
            # Shape d_model x n_head x d_head
            w_r = self.r_kernel

            # Shape batch_size x seq_len x n_head x d_model
            q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
            q_r_attention_1 = q_r_attention * phi[:, None]
            q_r_attention_2 = q_r_attention * pi[:, None]

            # Shape batch_size x n_head x seq_len x context_len
            positional_attn = torch.einsum("bind,jd->bnij", q_r_attention_1, psi) + torch.einsum(
                "bind,jd->bnij", q_r_attention_2, omega
            )
        else:
            shift = 2 if q_head.shape[1] != context_len else 1
            # Notations from the paper, appendix A.2.1, final formula (https://arxiv.org/abs/2006.03236)
            # Grab the proper positional encoding, shape max_rel_len x d_model
            r = position_embeds[self.block_index][shift - 1]
            # Shape n_head x d_head
            v = self.r_r_bias * self.scale
            # Shape d_model x n_head x d_head
            w_r = self.r_kernel

            # Shape max_rel_len x n_head x d_head
            r_head = torch.einsum("td,dnh->tnh", r, w_r)
            # Shape batch_size x n_head x seq_len x max_rel_len
            positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
            # Shape batch_size x n_head x seq_len x context_len
            positional_attn = _relative_shift_gather(positional_attn, context_len, shift)

        if cls_mask is not None:
            positional_attn *= cls_mask
        return positional_attn
    def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
        """Relative attention score for the token_type_ids"""
        if token_type_mat is None:
            return 0
        batch_size, seq_len, context_len = token_type_mat.shape
        # q_head has shape batch_size x seq_len x n_head x d_head
        # Shape n_head x d_head
        r_s_bias = self.r_s_bias * self.scale

        # Shape batch_size x n_head x seq_len x 2
        token_type_bias = torch.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
        # Shape batch_size x n_head x seq_len x context_len
        token_type_mat = token_type_mat[:, None].expand([batch_size, q_head.shape[2], seq_len, context_len])
        # Shapes batch_size x n_head x seq_len
        diff_token_type, same_token_type = torch.split(token_type_bias, 1, dim=-1)
        # Shape batch_size x n_head x seq_len x context_len
        token_type_attn = torch.where(
            token_type_mat, same_token_type.expand(token_type_mat.shape), diff_token_type.expand(token_type_mat.shape)
        )

        if cls_mask is not None:
            token_type_attn *= cls_mask
        return token_type_attn
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_inputs: Tuple[torch.Tensor],
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        # query has shape batch_size x seq_len x d_model
        # key and value have shapes batch_size x context_len x d_model
        position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs

        batch_size, seq_len, _ = query.shape
        context_len = key.shape[1]
        n_head, d_head = self.config.n_head, self.config.d_head

        # Shape batch_size x seq_len x n_head x d_head
        q_head = self.q_head(query).view(batch_size, seq_len, n_head, d_head)
        # Shapes batch_size x context_len x n_head x d_head
        k_head = self.k_head(key).view(batch_size, context_len, n_head, d_head)
        v_head = self.v_head(value).view(batch_size, context_len, n_head, d_head)

        q_head = q_head * self.scale
        # Shape n_head x d_head
        r_w_bias = self.r_w_bias * self.scale
        # Shapes batch_size x n_head x seq_len x context_len
        content_score = torch.einsum("bind,bjnd->bnij", q_head + r_w_bias, k_head)
        positional_attn = self.relative_positional_attention(position_embeds, q_head, context_len, cls_mask)
        token_type_attn = self.relative_token_type_attention(token_type_mat, q_head, cls_mask)

        # merge attention scores
        attn_score = content_score + positional_attn + token_type_attn

        # precision safe in case of mixed precision training
        dtype = attn_score.dtype
        attn_score = attn_score.float()
        # perform masking
        if attention_mask is not None:
            attn_score = attn_score - INF * (1 - attention_mask[:, None, None].float())
        # attention probability
        attn_prob = torch.softmax(attn_score, dim=-1, dtype=dtype)
        attn_prob = self.attention_dropout(attn_prob)

        # attention output, shape batch_size x seq_len x n_head x d_head
        attn_vec = torch.einsum("bnij,bjnd->bind", attn_prob, v_head)

        # Shape batch_size x seq_len x d_model
        attn_out = self.post_proj(attn_vec.reshape(batch_size, seq_len, n_head * d_head))
        attn_out = self.hidden_dropout(attn_out)

        output = self.layer_norm(query + attn_out)
        return (output, attn_prob) if output_attentions else (output,)
class FunnelPositionwiseFFN(nn.Module):
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(config.d_model, config.d_inner)
        self.activation_function = ACT2FN[config.hidden_act]
        self.activation_dropout = nn.Dropout(config.activation_dropout)
        self.linear_2 = nn.Linear(config.d_inner, config.d_model)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        h = self.linear_1(hidden)
        h = self.activation_function(h)
        h = self.activation_dropout(h)
        h = self.linear_2(h)
        h = self.dropout(h)
        return self.layer_norm(hidden + h)
class FunnelLayer(nn.Module):
    def __init__(self, config: FunnelConfig, block_index: int) -> None:
        super().__init__()
        self.attention = FunnelRelMultiheadAttention(config, block_index)
        self.ffn = FunnelPositionwiseFFN(config)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_inputs,
        output_attentions: bool = False,
    ) -> Tuple:
        attn = self.attention(query, key, value, attention_inputs, output_attentions=output_attentions)
        output = self.ffn(attn[0])
        return (output, attn[1]) if output_attentions else (output,)
class FunnelEncoder(nn.Module):
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__()
        self.config = config
        self.attention_structure = FunnelAttentionStructure(config)
        self.blocks = nn.ModuleList(
            [
                nn.ModuleList([FunnelLayer(config, block_index) for _ in range(block_size)])
                for block_index, block_size in enumerate(config.block_sizes)
            ]
        )

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, BaseModelOutput]:
        # The pooling is not implemented on long tensors, so we convert this mask.
        attention_mask = attention_mask.type_as(inputs_embeds)
        attention_inputs = self.attention_structure.init_attention_inputs(
            inputs_embeds,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden = inputs_embeds

        all_hidden_states = (inputs_embeds,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        for block_index, block in enumerate(self.blocks):
            pooling_flag = hidden.size(1) > (2 if self.config.separate_cls else 1)
            pooling_flag = pooling_flag and block_index > 0
            if pooling_flag:
                pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
                    hidden, attention_inputs
                )
            for layer_index, layer in enumerate(block):
                for repeat_index in range(self.config.block_repeats[block_index]):
                    do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
                    if do_pooling:
                        query = pooled_hidden
                        key = value = hidden if self.config.pool_q_only else pooled_hidden
                    else:
                        query = key = value = hidden
                    layer_output = layer(query, key, value, attention_inputs, output_attentions=output_attentions)
                    hidden = layer_output[0]
                    if do_pooling:
                        attention_inputs = self.attention_structure.post_attention_pooling(attention_inputs)
                    if output_attentions:
                        all_attentions = all_attentions + layer_output[1:]
                    if output_hidden_states:
                        all_hidden_states = all_hidden_states + (hidden,)

        if not return_dict:
            return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
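
# Editor's note (illustrative, not upstream code): the encoder halves the sequence length at the
# start of every block after the first. For example, with config.block_sizes = [4, 4, 4] and an
# input of 512 tokens, the hidden states processed by the three blocks have lengths 512, 256 and
# 128 respectively; the decoder below then upsamples the final 128-length states back to 512.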
def upsample(
    x: torch.Tensor, stride: int, target_len: int, separate_cls: bool = True, truncate_seq: bool = False
) -> torch.Tensor:
    """
    Upsample tensor `x` to match `target_len` by repeating the tokens `stride` times on the sequence length dimension.
    """
    if stride == 1:
        return x
    if separate_cls:
        cls = x[:, :1]
        x = x[:, 1:]
    output = torch.repeat_interleave(x, repeats=stride, dim=1)
    if separate_cls:
        if truncate_seq:
            output = nn.functional.pad(output, (0, 0, 0, stride - 1, 0, 0))
        output = output[:, : target_len - 1]
        output = torch.cat([cls, output], dim=1)
    else:
        output = output[:, :target_len]
    return output
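
# A small worked example (editor's sketch, not upstream code), assuming separate_cls=True and
# truncate_seq=False: upsampling x = [cls, a, b] with stride=2 and target_len=5 keeps the <cls>
# token aside, repeats the rest to [a, a, b, b], trims to target_len - 1 tokens and re-prepends
# <cls>, giving [cls, a, a, b, b].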
class FunnelDecoder(nn.Module):
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__()
        self.config = config
        self.attention_structure = FunnelAttentionStructure(config)
        self.layers = nn.ModuleList([FunnelLayer(config, 0) for _ in range(config.num_decoder_layers)])

    def forward(
        self,
        final_hidden: torch.Tensor,
        first_block_hidden: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[Tuple, BaseModelOutput]:
        upsampled_hidden = upsample(
            final_hidden,
            stride=2 ** (len(self.config.block_sizes) - 1),
            target_len=first_block_hidden.shape[1],
            separate_cls=self.config.separate_cls,
            truncate_seq=self.config.truncate_seq,
        )

        hidden = upsampled_hidden + first_block_hidden
        all_hidden_states = (hidden,) if output_hidden_states else None
        all_attentions = () if output_attentions else None

        attention_inputs = self.attention_structure.init_attention_inputs(
            hidden,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        for layer in self.layers:
            layer_output = layer(hidden, hidden, hidden, attention_inputs, output_attentions=output_attentions)
            hidden = layer_output[0]

            if output_attentions:
                all_attentions = all_attentions + layer_output[1:]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden,)

        if not return_dict:
            return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
        return BaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
class FunnelDiscriminatorPredictions(nn.Module):
    """Prediction module for the discriminator, made up of two dense layers."""

    def __init__(self, config: FunnelConfig) -> None:
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.d_model, config.d_model)
        self.dense_prediction = nn.Linear(config.d_model, 1)

    def forward(self, discriminator_hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(discriminator_hidden_states)
        hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
        logits = self.dense_prediction(hidden_states).squeeze(-1)
        return logits
class FunnelPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FunnelConfig
    load_tf_weights = load_tf_weights_in_funnel
    base_model_prefix = "funnel"

    def _init_weights(self, module):
        classname = module.__class__.__name__
        if classname.find("Linear") != -1:
            if getattr(module, "weight", None) is not None:
                if self.config.initializer_std is None:
                    fan_out, fan_in = module.weight.shape
                    std = np.sqrt(1.0 / float(fan_in + fan_out))
                else:
                    std = self.config.initializer_std
                nn.init.normal_(module.weight, std=std)
            if getattr(module, "bias", None) is not None:
                nn.init.constant_(module.bias, 0.0)
        elif classname == "FunnelRelMultiheadAttention":
            nn.init.uniform_(module.r_w_bias, b=self.config.initializer_range)
            nn.init.uniform_(module.r_r_bias, b=self.config.initializer_range)
            nn.init.uniform_(module.r_kernel, b=self.config.initializer_range)
            nn.init.uniform_(module.r_s_bias, b=self.config.initializer_range)
            nn.init.uniform_(module.seg_embed, b=self.config.initializer_range)
        elif classname == "FunnelEmbeddings":
            std = 1.0 if self.config.initializer_std is None else self.config.initializer_std
            nn.init.normal_(module.word_embeddings.weight, std=std)
            if module.word_embeddings.padding_idx is not None:
                module.word_embeddings.weight.data[module.word_embeddings.padding_idx].zero_()
class FunnelClassificationHead(nn.Module):
    def __init__(self, config: FunnelConfig, n_labels: int) -> None:
        super().__init__()
        self.linear_hidden = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.linear_out = nn.Linear(config.d_model, n_labels)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        hidden = self.linear_hidden(hidden)
        hidden = torch.tanh(hidden)
        hidden = self.dropout(hidden)
        return self.linear_out(hidden)
@dataclass
class FunnelForPreTrainingOutput(ModelOutput):
    """
    Output type of [`FunnelForPreTraining`].

    Args:
        loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Total loss of the ELECTRA-style objective.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Prediction scores of the head (scores for each token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
FUNNEL_START_DOCSTRING = r"""

    The Funnel Transformer model was proposed in [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient
    Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`FunnelConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

FUNNEL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    """
    The base Funnel Transformer Model transformer outputting raw hidden-states without upsampling head (also called
    decoder) or any task-specific head on top.
    """,
    FUNNEL_START_DOCSTRING,
)
class FunnelBaseModel(FunnelPreTrainedModel):
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__(config)

        self.embeddings = FunnelEmbeddings(config)
        self.encoder = FunnelEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
        self.embeddings.word_embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="funnel-transformer/small-base",
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # TODO: deal with head_mask
        inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)

        encoder_outputs = self.encoder(
            inputs_embeds,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs
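
# A minimal usage sketch (editor's illustration, not part of the upstream module); the
# "funnel-transformer/small-base" checkpoint is the one referenced in the code-sample decorator
# above. Note that FunnelBaseModel has no decoder, so its last_hidden_state is at the pooled
# (shorter) sequence length:
#
# >>> from transformers import AutoTokenizer, FunnelBaseModel
# >>> tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small-base")
# >>> model = FunnelBaseModel.from_pretrained("funnel-transformer/small-base")
# >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# >>> last_hidden_state = model(**inputs).last_hidden_state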
@add_start_docstrings(
    "The bare Funnel Transformer Model transformer outputting raw hidden-states without any specific head on top.",
    FUNNEL_START_DOCSTRING,
)
class FunnelModel(FunnelPreTrainedModel):
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__(config)
        self.config = config
        self.embeddings = FunnelEmbeddings(config)
        self.encoder = FunnelEncoder(config)
        self.decoder = FunnelDecoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
        self.embeddings.word_embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # TODO: deal with head_mask
        inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)

        encoder_outputs = self.encoder(
            inputs_embeds,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        decoder_outputs = self.decoder(
            final_hidden=encoder_outputs[0],
            first_block_hidden=encoder_outputs[1][self.config.block_sizes[0]],
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            idx = 0
            outputs = (decoder_outputs[0],)
            if output_hidden_states:
                idx += 1
                outputs = outputs + (encoder_outputs[1] + decoder_outputs[idx],)
            if output_attentions:
                idx += 1
                outputs = outputs + (encoder_outputs[2] + decoder_outputs[idx],)
            return outputs

        return BaseModelOutput(
            last_hidden_state=decoder_outputs[0],
            hidden_states=(encoder_outputs.hidden_states + decoder_outputs.hidden_states)
            if output_hidden_states
            else None,
            attentions=(encoder_outputs.attentions + decoder_outputs.attentions) if output_attentions else None,
        )
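
# A minimal usage sketch (editor's illustration, not part of the upstream module): unlike
# FunnelBaseModel, FunnelModel also runs the decoder, so last_hidden_state is back at the full
# input resolution.
#
# >>> from transformers import AutoTokenizer, FunnelModel
# >>> tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small")
# >>> model = FunnelModel.from_pretrained("funnel-transformer/small")
# >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# >>> outputs = model(**inputs)
# >>> outputs.last_hidden_state.shape[1] == inputs["input_ids"].shape[1]
# True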
@add_start_docstrings(
    """
    Funnel Transformer model with a binary classification head on top as used during pretraining for identifying
    generated tokens.
    """,
    FUNNEL_START_DOCSTRING,
)
class FunnelForPreTraining(FunnelPreTrainedModel):
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__(config)

        self.funnel = FunnelModel(config)
        self.discriminator_predictions = FunnelDiscriminatorPredictions(config)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=FunnelForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FunnelForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the ELECTRA-style loss. Input should be a sequence of tokens (see `input_ids`
            docstring). Indices should be in `[0, 1]`:

            - 0 indicates the token is an original token,
            - 1 indicates the token was replaced.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, FunnelForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small")
        >>> model = FunnelForPreTraining.from_pretrained("funnel-transformer/small")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> logits = model(**inputs).logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        discriminator_hidden_states = self.funnel(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]

        logits = self.discriminator_predictions(discriminator_sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.BCEWithLogitsLoss()
            if attention_mask is not None:
                active_loss = attention_mask.view(-1, discriminator_sequence_output.shape[1]) == 1
                active_logits = logits.view(-1, discriminator_sequence_output.shape[1])[active_loss]
                active_labels = labels[active_loss]
                loss = loss_fct(active_logits, active_labels.float())
            else:
                loss = loss_fct(logits.view(-1, discriminator_sequence_output.shape[1]), labels.float())

        if not return_dict:
            output = (logits,) + discriminator_hidden_states[1:]
            return ((loss,) + output) if loss is not None else output

        return FunnelForPreTrainingOutput(
            loss=loss,
            logits=logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )
  1003. @add_start_docstrings("""Funnel Transformer Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING)
  1004. class FunnelForMaskedLM(FunnelPreTrainedModel):
  1005. _tied_weights_keys = ["lm_head.weight"]
  1006. def __init__(self, config: FunnelConfig) -> None:
  1007. super().__init__(config)
  1008. self.funnel = FunnelModel(config)
  1009. self.lm_head = nn.Linear(config.d_model, config.vocab_size)
  1010. # Initialize weights and apply final processing
  1011. self.post_init()
  1012. def get_output_embeddings(self) -> nn.Linear:
  1013. return self.lm_head
  1014. def set_output_embeddings(self, new_embeddings: nn.Embedding) -> None:
  1015. self.lm_head = new_embeddings
  1016. @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  1017. @add_code_sample_docstrings(
  1018. checkpoint=_CHECKPOINT_FOR_DOC,
  1019. output_type=MaskedLMOutput,
  1020. config_class=_CONFIG_FOR_DOC,
  1021. mask="<mask>",
  1022. )
  1023. def forward(
  1024. self,
  1025. input_ids: Optional[torch.Tensor] = None,
  1026. attention_mask: Optional[torch.Tensor] = None,
  1027. token_type_ids: Optional[torch.Tensor] = None,
  1028. inputs_embeds: Optional[torch.Tensor] = None,
  1029. labels: Optional[torch.Tensor] = None,
  1030. output_attentions: Optional[bool] = None,
  1031. output_hidden_states: Optional[bool] = None,
  1032. return_dict: Optional[bool] = None,
  1033. ) -> Union[Tuple, MaskedLMOutput]:
  1034. r"""
  1035. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1036. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  1037. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  1038. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  1039. """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.funnel(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs[0]
        prediction_logits = self.lm_head(last_hidden_state)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()  # -100 index = padding token
            masked_lm_loss = loss_fct(prediction_logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (prediction_logits,) + outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
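
# Usage sketch (illustrative only): filling in the `<mask>` token with FunnelForMaskedLM. The "funnel-transformer/small"
# checkpoint and the example sentence are assumptions for this sketch; any Funnel checkpoint with an LM head works the
# same way.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small")
#     model = FunnelForMaskedLM.from_pretrained("funnel-transformer/small")
#     inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
#     logits = model(**inputs).logits                                                   # (batch_size, seq_len, vocab_size)
#     mask_positions = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
#     predicted_ids = logits[0, mask_positions].argmax(dim=-1)                          # best token id per masked slot
#     print(tokenizer.decode(predicted_ids))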


@add_start_docstrings(
    """
    Funnel Transformer Model with a sequence classification/regression head on top (two linear layers on top of the
    first timestep of the last hidden state), e.g. for GLUE tasks.
    """,
    FUNNEL_START_DOCSTRING,
)
class FunnelForSequenceClassification(FunnelPreTrainedModel):
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.funnel = FunnelBaseModel(config)
        self.classifier = FunnelClassificationHead(config, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="funnel-transformer/small-base",
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
  1098. r"""
  1099. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1100. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1101. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1102. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1103. """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.funnel(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs[0]
        pooled_output = last_hidden_state[:, 0]
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
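
# Usage sketch (illustrative only): single-label sequence classification with the no-decoder "base" checkpoint
# referenced in the decorator above. `num_labels=2`, the sentence, and the example label are assumptions; a freshly
# added classification head is randomly initialized until fine-tuned.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small-base")
#     model = FunnelForSequenceClassification.from_pretrained("funnel-transformer/small-base", num_labels=2)
#     inputs = tokenizer("A thoroughly enjoyable film.", return_tensors="pt")
#     outputs = model(**inputs, labels=torch.tensor([1]))
#     loss, logits = outputs.loss, outputs.logits          # logits shape: (batch_size, num_labels)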


@add_start_docstrings(
    """
    Funnel Transformer Model with a multiple choice classification head on top (two linear layers on top of the first
    timestep of the last hidden state, and a softmax), e.g. for RocStories/SWAG tasks.
    """,
    FUNNEL_START_DOCSTRING,
)
class FunnelForMultipleChoice(FunnelPreTrainedModel):
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__(config)

        self.funnel = FunnelBaseModel(config)
        self.classifier = FunnelClassificationHead(config, 1)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint="funnel-transformer/small-base",
        output_type=MultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MultipleChoiceModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices - 1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above.)
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        outputs = self.funnel(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs[0]
        pooled_output = last_hidden_state[:, 0]
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
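
# Usage sketch (illustrative only): the forward above expects inputs of shape (batch_size, num_choices, seq_len) and
# flattens them before calling the funnel encoder. The prompt, choices, and label below are arbitrary assumptions.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small-base")
#     model = FunnelForMultipleChoice.from_pretrained("funnel-transformer/small-base")
#     prompt = "The weather today is"
#     choices = ["sunny and warm.", "a kind of fruit."]
#     encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
#     inputs = {k: v.unsqueeze(0) for k, v in encoding.items()}   # add the batch dimension: (1, num_choices, seq_len)
#     outputs = model(**inputs, labels=torch.tensor([0]))         # label 0 = the first choice is correct
#     loss, logits = outputs.loss, outputs.logits                 # logits shape: (batch_size, num_choices)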


@add_start_docstrings(
    """
    Funnel Transformer Model with a token classification head on top (a linear layer on top of the hidden-states
    output), e.g. for Named-Entity-Recognition (NER) tasks.
    """,
    FUNNEL_START_DOCSTRING,
)
class FunnelForTokenClassification(FunnelPreTrainedModel):
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__(config)
        self.num_labels = config.num_labels

        self.funnel = FunnelModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.funnel(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs[0]
        last_hidden_state = self.dropout(last_hidden_state)
        logits = self.classifier(last_hidden_state)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
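
# Usage sketch (illustrative only): per-token classification, e.g. NER. The `num_labels=5` value and the sentence are
# assumptions; the classification head is randomly initialized and gives meaningless predictions until fine-tuned.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small")
#     model = FunnelForTokenClassification.from_pretrained("funnel-transformer/small", num_labels=5)
#     inputs = tokenizer("Hugging Face is based in New York City.", return_tensors="pt")
#     logits = model(**inputs).logits                  # (batch_size, seq_len, num_labels)
#     predicted_class_ids = logits.argmax(dim=-1)      # one class id per token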


@add_start_docstrings(
    """
    Funnel Transformer Model with a span classification head on top for extractive question-answering tasks like SQuAD
    (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    FUNNEL_START_DOCSTRING,
)
class FunnelForQuestionAnswering(FunnelPreTrainedModel):
    def __init__(self, config: FunnelConfig) -> None:
        super().__init__(config)
        self.num_labels = config.num_labels

        self.funnel = FunnelModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
  1316. r"""
  1317. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1318. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  1319. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  1320. are not taken into account for computing the loss.
  1321. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1322. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  1323. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  1324. are not taken into account for computing the loss.
  1325. """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.funnel(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs[0]

        logits = self.qa_outputs(last_hidden_state)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, the gathered start/end positions may have an extra dimension; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
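
# Usage sketch (illustrative only): extractive QA with the span head above. The independent argmax decoding is the
# simplest strategy (it ignores the start <= end constraint); the question/context pair is an arbitrary assumption.
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small")
#     model = FunnelForQuestionAnswering.from_pretrained("funnel-transformer/small")
#     question = "Who wrote Hamlet?"
#     context = "Hamlet is a tragedy written by William Shakespeare."
#     inputs = tokenizer(question, context, return_tensors="pt")
#     outputs = model(**inputs)
#     start = int(outputs.start_logits.argmax(dim=-1))
#     end = int(outputs.end_logits.argmax(dim=-1))
#     answer = tokenizer.decode(inputs.input_ids[0, start : end + 1])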