| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389 |
- # coding=utf-8
- # Copyright 2020 Google Research and The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch TAPAS model."""
- import enum
- import math
- import os
- from dataclasses import dataclass
- from typing import Optional, Tuple, Union
- import torch
- import torch.utils.checkpoint
- from torch import nn
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
- from ...activations import ACT2FN
- from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import (
- apply_chunking_to_forward,
- find_pruneable_heads_and_indices,
- is_torch_greater_or_equal_than_1_12,
- prune_linear_layer,
- )
- from ...utils import (
- ModelOutput,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- replace_return_docstrings,
- )
- from .configuration_tapas import TapasConfig
- logger = logging.get_logger(__name__)
- if not is_torch_greater_or_equal_than_1_12:
- logger.warning(
- f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
- "TapasModel. Please upgrade torch."
- )
- _CONFIG_FOR_DOC = "TapasConfig"
- _CHECKPOINT_FOR_DOC = "google/tapas-base"
- EPSILON_ZERO_DIVISION = 1e-10
- CLOSE_ENOUGH_TO_LOG_ZERO = -10000.0
- @dataclass
- class TableQuestionAnsweringOutput(ModelOutput):
- """
- Output type of [`TapasForQuestionAnswering`].
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` (and possibly `answer`, `aggregation_labels`, `numeric_values` and `numeric_values_scale` are provided)):
- Total loss as the sum of the hierarchical cell selection log-likelihood loss and (optionally) the
- semi-supervised regression loss and (optionally) supervised loss for aggregations.
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
- Prediction scores of the cell selection head, for every token.
- logits_aggregation (`torch.FloatTensor`, *optional*, of shape `(batch_size, num_aggregation_labels)`):
- Prediction scores of the aggregation head, for every aggregation operator.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
- shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
- plus the initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
- the self-attention heads.
- """
- loss: Optional[torch.FloatTensor] = None
- logits: torch.FloatTensor = None
- logits_aggregation: torch.FloatTensor = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
- def load_tf_weights_in_tapas(model, config, tf_checkpoint_path):
- """
- Load tf checkpoints in a PyTorch model. This is an adaptation from load_tf_weights_in_bert
- - add cell selection and aggregation heads
- - take into account additional token type embedding layers
- """
- try:
- import re
- import numpy as np
- import tensorflow as tf
- except ImportError:
- logger.error(
- "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
- "https://www.tensorflow.org/install/ for installation instructions."
- )
- raise
- tf_path = os.path.abspath(tf_checkpoint_path)
- logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
- # Load weights from TF model
- init_vars = tf.train.list_variables(tf_path)
- names = []
- arrays = []
- for name, shape in init_vars:
- logger.info(f"Loading TF weight {name} with shape {shape}")
- array = tf.train.load_variable(tf_path, name)
- names.append(name)
- arrays.append(array)
- for name, array in zip(names, arrays):
- name = name.split("/")
- # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
- # which are not required for using pretrained model
- if any(
- n
- in [
- "adam_v",
- "adam_m",
- "AdamWeightDecayOptimizer",
- "AdamWeightDecayOptimizer_1",
- "global_step",
- "seq_relationship",
- ]
- for n in name
- ):
- logger.info(f"Skipping {'/'.join(name)}")
- continue
- # in case the model is TapasForSequenceClassification, we skip output_bias and output_weights
- # since these are not used for classification
- if isinstance(model, TapasForSequenceClassification):
- if any(n in ["output_bias", "output_weights"] for n in name):
- logger.info(f"Skipping {'/'.join(name)}")
- continue
- # in case the model is TapasModel, we skip output_bias, output_weights, output_bias_cls and output_weights_cls
- # since this model does not have MLM and NSP heads
- if isinstance(model, TapasModel):
- if any(n in ["output_bias", "output_weights", "output_bias_cls", "output_weights_cls"] for n in name):
- logger.info(f"Skipping {'/'.join(name)}")
- continue
- # in case the model is TapasForMaskedLM, we skip the pooler
- if isinstance(model, TapasForMaskedLM):
- if any(n in ["pooler"] for n in name):
- logger.info(f"Skipping {'/'.join(name)}")
- continue
- # if first scope name starts with "bert", change it to "tapas"
- if name[0] == "bert":
- name[0] = "tapas"
- pointer = model
- for m_name in name:
- if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
- scope_names = re.split(r"_(\d+)", m_name)
- else:
- scope_names = [m_name]
- if scope_names[0] == "kernel" or scope_names[0] == "gamma":
- pointer = getattr(pointer, "weight")
- elif scope_names[0] == "beta":
- pointer = getattr(pointer, "bias")
- # cell selection heads
- elif scope_names[0] == "output_bias":
- if not isinstance(model, TapasForMaskedLM):
- pointer = getattr(pointer, "output_bias")
- else:
- pointer = getattr(pointer, "bias")
- elif scope_names[0] == "output_weights":
- pointer = getattr(pointer, "output_weights")
- elif scope_names[0] == "column_output_bias":
- pointer = getattr(pointer, "column_output_bias")
- elif scope_names[0] == "column_output_weights":
- pointer = getattr(pointer, "column_output_weights")
- # aggregation head
- elif scope_names[0] == "output_bias_agg":
- pointer = getattr(pointer, "aggregation_classifier")
- pointer = getattr(pointer, "bias")
- elif scope_names[0] == "output_weights_agg":
- pointer = getattr(pointer, "aggregation_classifier")
- pointer = getattr(pointer, "weight")
- # classification head
- elif scope_names[0] == "output_bias_cls":
- pointer = getattr(pointer, "classifier")
- pointer = getattr(pointer, "bias")
- elif scope_names[0] == "output_weights_cls":
- pointer = getattr(pointer, "classifier")
- pointer = getattr(pointer, "weight")
- else:
- try:
- pointer = getattr(pointer, scope_names[0])
- except AttributeError:
- logger.info(f"Skipping {'/'.join(name)}")
- continue
- if len(scope_names) >= 2:
- num = int(scope_names[1])
- pointer = pointer[num]
- if m_name[-11:] == "_embeddings":
- pointer = getattr(pointer, "weight")
- elif m_name[-13:] in [f"_embeddings_{i}" for i in range(7)]:
- pointer = getattr(pointer, "weight")
- elif m_name == "kernel":
- array = np.transpose(array)
- try:
- if pointer.shape != array.shape:
- raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
- except AssertionError as e:
- e.args += (pointer.shape, array.shape)
- raise
- logger.info(f"Initialize PyTorch weight {name}")
- # Added a check to see whether the array is a scalar (because bias terms in Tapas checkpoints can be
- # scalar => should first be converted to numpy arrays)
- if np.isscalar(array):
- array = np.array(array)
- pointer.data = torch.from_numpy(array)
- return model
- class TapasEmbeddings(nn.Module):
- """
- Construct the embeddings from word, position and token_type embeddings. Same as BertEmbeddings but with a number of
- additional token type embeddings to encode tabular structure.
- """
- def __init__(self, config):
- super().__init__()
- # we do not include config.disabled_features and config.disable_position_embeddings from the original implementation
- # word embeddings
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
- # position embeddings
- self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
- # token type embeddings
- for i, type_vocab_sizes in enumerate(config.type_vocab_sizes):
- name = f"token_type_embeddings_{i}"
- setattr(self, name, nn.Embedding(type_vocab_sizes, config.hidden_size))
- self.number_of_token_type_embeddings = len(config.type_vocab_sizes)
- # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
- # any TensorFlow checkpoint file
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
- self.config = config
- def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
- if input_ids is not None:
- input_shape = input_ids.size()
- else:
- input_shape = inputs_embeds.size()[:-1]
- seq_length = input_shape[1]
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- if position_ids is None:
- # create absolute position embeddings
- position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
- position_ids = position_ids.unsqueeze(0).expand(input_shape)
- # when self.config.reset_position_index_per_cell is set to True, create relative position embeddings
- if self.config.reset_position_index_per_cell:
- # shape (batch_size, seq_len)
- col_index = IndexMap(token_type_ids[:, :, 1], self.config.type_vocab_sizes[1], batch_dims=1)
- # shape (batch_size, seq_len)
- row_index = IndexMap(token_type_ids[:, :, 2], self.config.type_vocab_sizes[2], batch_dims=1)
- # shape (batch_size, seq_len)
- full_index = ProductIndexMap(col_index, row_index)
- # shape (max_rows * max_columns,). First absolute position for every cell
- first_position_per_segment = reduce_min(position_ids, full_index)[0]
- # ? shape (batch_size, seq_len). First absolute position of the cell for every token
- first_position = gather(first_position_per_segment, full_index)
- # shape (1, seq_len)
- position = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0)
- position_ids = torch.min(
- torch.as_tensor(self.config.max_position_embeddings - 1, device=device), position - first_position
- )
- if token_type_ids is None:
- token_type_ids = torch.zeros(
- (input_shape + self.number_of_token_type_embeddings), dtype=torch.long, device=device
- )
- if inputs_embeds is None:
- inputs_embeds = self.word_embeddings(input_ids)
- position_embeddings = self.position_embeddings(position_ids)
- embeddings = inputs_embeds + position_embeddings
- for i in range(self.number_of_token_type_embeddings):
- name = f"token_type_embeddings_{i}"
- embeddings += getattr(self, name)(token_type_ids[:, :, i])
- embeddings = self.LayerNorm(embeddings)
- embeddings = self.dropout(embeddings)
- return embeddings
- class TapasSelfAttention(nn.Module):
- def __init__(self, config):
- super().__init__()
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
- raise ValueError(
- f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
- f"heads {config.num_attention_heads}"
- )
- self.num_attention_heads = config.num_attention_heads
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
- self.all_head_size = self.num_attention_heads * self.attention_head_size
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
- self.is_decoder = config.is_decoder
- def transpose_for_scores(self, x):
- new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
- return x.permute(0, 2, 1, 3)
- def forward(
- self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_value=None,
- output_attentions=False,
- ):
- mixed_query_layer = self.query(hidden_states)
- # If this is instantiated as a cross-attention module, the keys
- # and values come from an encoder; the attention mask needs to be
- # such that the encoder's padding tokens are not attended to.
- is_cross_attention = encoder_hidden_states is not None
- if is_cross_attention and past_key_value is not None:
- # reuse k,v, cross_attentions
- key_layer = past_key_value[0]
- value_layer = past_key_value[1]
- attention_mask = encoder_attention_mask
- elif is_cross_attention:
- key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
- value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
- attention_mask = encoder_attention_mask
- elif past_key_value is not None:
- key_layer = self.transpose_for_scores(self.key(hidden_states))
- value_layer = self.transpose_for_scores(self.value(hidden_states))
- key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
- value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
- else:
- key_layer = self.transpose_for_scores(self.key(hidden_states))
- value_layer = self.transpose_for_scores(self.value(hidden_states))
- query_layer = self.transpose_for_scores(mixed_query_layer)
- if self.is_decoder:
- past_key_value = (key_layer, value_layer)
- # Take the dot product between "query" and "key" to get the raw attention scores.
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
- if attention_mask is not None:
- # Apply the attention mask is (precomputed for all layers in TapasModel forward() function)
- attention_scores = attention_scores + attention_mask
- # Normalize the attention scores to probabilities.
- attention_probs = nn.functional.softmax(attention_scores, dim=-1)
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- attention_probs = self.dropout(attention_probs)
- # Mask heads if we want to
- if head_mask is not None:
- attention_probs = attention_probs * head_mask
- context_layer = torch.matmul(attention_probs, value_layer)
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
- if self.is_decoder:
- outputs = outputs + (past_key_value,)
- return outputs
- # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
- class TapasSelfOutput(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
- hidden_states = self.dense(hidden_states)
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
- return hidden_states
- class TapasAttention(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.self = TapasSelfAttention(config)
- self.output = TapasSelfOutput(config)
- self.pruned_heads = set()
- # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
- def prune_heads(self, heads):
- if len(heads) == 0:
- return
- heads, index = find_pruneable_heads_and_indices(
- heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
- )
- # Prune linear layers
- self.self.query = prune_linear_layer(self.self.query, index)
- self.self.key = prune_linear_layer(self.self.key, index)
- self.self.value = prune_linear_layer(self.self.value, index)
- self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
- # Update hyper params and store pruned heads
- self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
- self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
- self.pruned_heads = self.pruned_heads.union(heads)
- # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.FloatTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
- output_attentions: Optional[bool] = False,
- ) -> Tuple[torch.Tensor]:
- self_outputs = self.self(
- hidden_states,
- attention_mask,
- head_mask,
- encoder_hidden_states,
- encoder_attention_mask,
- past_key_value,
- output_attentions,
- )
- attention_output = self.output(self_outputs[0], hidden_states)
- outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
- return outputs
- # Copied from transformers.models.bert.modeling_bert.BertIntermediate
- class TapasIntermediate(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
- if isinstance(config.hidden_act, str):
- self.intermediate_act_fn = ACT2FN[config.hidden_act]
- else:
- self.intermediate_act_fn = config.hidden_act
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- hidden_states = self.dense(hidden_states)
- hidden_states = self.intermediate_act_fn(hidden_states)
- return hidden_states
- # Copied from transformers.models.bert.modeling_bert.BertOutput
- class TapasOutput(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
- def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
- hidden_states = self.dense(hidden_states)
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
- return hidden_states
- class TapasLayer(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.chunk_size_feed_forward = config.chunk_size_feed_forward
- self.seq_len_dim = 1
- self.attention = TapasAttention(config)
- self.is_decoder = config.is_decoder
- self.add_cross_attention = config.add_cross_attention
- if self.add_cross_attention:
- if not self.is_decoder:
- raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
- self.crossattention = TapasAttention(config)
- self.intermediate = TapasIntermediate(config)
- self.output = TapasOutput(config)
- # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.FloatTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
- output_attentions: Optional[bool] = False,
- ) -> Tuple[torch.Tensor]:
- # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
- self_attention_outputs = self.attention(
- hidden_states,
- attention_mask,
- head_mask,
- output_attentions=output_attentions,
- past_key_value=self_attn_past_key_value,
- )
- attention_output = self_attention_outputs[0]
- # if decoder, the last output is tuple of self-attn cache
- if self.is_decoder:
- outputs = self_attention_outputs[1:-1]
- present_key_value = self_attention_outputs[-1]
- else:
- outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
- cross_attn_present_key_value = None
- if self.is_decoder and encoder_hidden_states is not None:
- if not hasattr(self, "crossattention"):
- raise ValueError(
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
- " by setting `config.add_cross_attention=True`"
- )
- # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
- cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
- cross_attention_outputs = self.crossattention(
- attention_output,
- attention_mask,
- head_mask,
- encoder_hidden_states,
- encoder_attention_mask,
- cross_attn_past_key_value,
- output_attentions,
- )
- attention_output = cross_attention_outputs[0]
- outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
- # add cross-attn cache to positions 3,4 of present_key_value tuple
- cross_attn_present_key_value = cross_attention_outputs[-1]
- present_key_value = present_key_value + cross_attn_present_key_value
- layer_output = apply_chunking_to_forward(
- self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
- )
- outputs = (layer_output,) + outputs
- # if decoder, return the attn key/values as the last output
- if self.is_decoder:
- outputs = outputs + (present_key_value,)
- return outputs
- # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
- def feed_forward_chunk(self, attention_output):
- intermediate_output = self.intermediate(attention_output)
- layer_output = self.output(intermediate_output, attention_output)
- return layer_output
- class TapasEncoder(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.layer = nn.ModuleList([TapasLayer(config) for _ in range(config.num_hidden_layers)])
- self.gradient_checkpointing = False
- def forward(
- self,
- hidden_states,
- attention_mask=None,
- head_mask=None,
- encoder_hidden_states=None,
- encoder_attention_mask=None,
- past_key_values=None,
- use_cache=None,
- output_attentions=False,
- output_hidden_states=False,
- return_dict=True,
- ):
- all_hidden_states = () if output_hidden_states else None
- all_attentions = () if output_attentions else None
- for i, layer_module in enumerate(self.layer):
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
- layer_head_mask = head_mask[i] if head_mask is not None else None
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- layer_module.__call__,
- hidden_states,
- attention_mask,
- layer_head_mask,
- encoder_hidden_states,
- encoder_attention_mask,
- past_key_values,
- output_attentions,
- )
- else:
- layer_outputs = layer_module(
- hidden_states,
- attention_mask,
- layer_head_mask,
- encoder_hidden_states,
- encoder_attention_mask,
- past_key_values,
- output_attentions,
- )
- hidden_states = layer_outputs[0]
- if output_attentions:
- all_attentions = all_attentions + (layer_outputs[1],)
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
- if not return_dict:
- return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
- return BaseModelOutput(
- last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
- )
- # Copied from transformers.models.bert.modeling_bert.BertPooler
- class TapasPooler(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
- self.activation = nn.Tanh()
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- # We "pool" the model by simply taking the hidden state corresponding
- # to the first token.
- first_token_tensor = hidden_states[:, 0]
- pooled_output = self.dense(first_token_tensor)
- pooled_output = self.activation(pooled_output)
- return pooled_output
- # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->Tapas
- class TapasPredictionHeadTransform(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
- if isinstance(config.hidden_act, str):
- self.transform_act_fn = ACT2FN[config.hidden_act]
- else:
- self.transform_act_fn = config.hidden_act
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- hidden_states = self.dense(hidden_states)
- hidden_states = self.transform_act_fn(hidden_states)
- hidden_states = self.LayerNorm(hidden_states)
- return hidden_states
- # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->Tapas
- class TapasLMPredictionHead(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.transform = TapasPredictionHeadTransform(config)
- # The output weights are the same as the input embeddings, but there is
- # an output-only bias for each token.
- self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
- self.bias = nn.Parameter(torch.zeros(config.vocab_size))
- # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
- self.decoder.bias = self.bias
- def _tie_weights(self):
- self.decoder.bias = self.bias
- def forward(self, hidden_states):
- hidden_states = self.transform(hidden_states)
- hidden_states = self.decoder(hidden_states)
- return hidden_states
- # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Tapas
- class TapasOnlyMLMHead(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.predictions = TapasLMPredictionHead(config)
- def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
- prediction_scores = self.predictions(sequence_output)
- return prediction_scores
- class TapasPreTrainedModel(PreTrainedModel):
- """
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
- models.
- """
- config_class = TapasConfig
- base_model_prefix = "tapas"
- supports_gradient_checkpointing = True
- _supports_param_buffer_assignment = False
- # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
- def _init_weights(self, module):
- """Initialize the weights"""
- if isinstance(module, nn.Linear):
- # Slightly different from the TF version which uses truncated_normal for initialization
- # cf https://github.com/pytorch/pytorch/pull/5617
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
- elif isinstance(module, nn.LayerNorm):
- module.bias.data.zero_()
- module.weight.data.fill_(1.0)
- TAPAS_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
- Parameters:
- config ([`TapasConfig`]): Model configuration class with all the parameters of the model.
- Initializing with a config file does not load the weights associated with the model, only the
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
- """
- TAPAS_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `({0})`):
- Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
- [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- [What are attention masks?](../glossary#attention-mask)
- token_type_ids (`torch.LongTensor` of shape `({0}, 7)`, *optional*):
- Token indices that encode tabular structure. Indices can be obtained using [`AutoTokenizer`]. See this
- class for more info.
- [What are token type IDs?](../glossary#token-type-ids)
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. If
- `reset_position_index_per_cell` of [`TapasConfig`] is set to `True`, relative position embeddings will be
- used. Selected in the range `[0, config.max_position_embeddings - 1]`.
- [What are position IDs?](../glossary#position-ids)
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: - 1
- indicates the head is **not masked**, - 0 indicates the head is **masked**.
- inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
- @add_start_docstrings(
- "The bare Tapas Model transformer outputting raw hidden-states without any specific head on top.",
- TAPAS_START_DOCSTRING,
- )
- class TapasModel(TapasPreTrainedModel):
- """
- This class is a small change compared to [`BertModel`], taking into account the additional token type ids.
- The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
- cross-attention is added between the self-attention layers, following the architecture described in [Attention is
- all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
- Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
- """
- def __init__(self, config, add_pooling_layer=True):
- super().__init__(config)
- self.config = config
- self.embeddings = TapasEmbeddings(config)
- self.encoder = TapasEncoder(config)
- self.pooler = TapasPooler(config) if add_pooling_layer else None
- # Initialize weights and apply final processing
- self.post_init()
- def get_input_embeddings(self):
- return self.embeddings.word_embeddings
- def set_input_embeddings(self, value):
- self.embeddings.word_embeddings = value
- def _prune_heads(self, heads_to_prune):
- """
- Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
- class PreTrainedModel
- """
- for layer, heads in heads_to_prune.items():
- self.encoder.layer[layer].attention.prune_heads(heads)
- @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
- @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- token_type_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
- r"""
- Returns:
- Examples:
- ```python
- >>> from transformers import AutoTokenizer, TapasModel
- >>> import pandas as pd
- >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base")
- >>> model = TapasModel.from_pretrained("google/tapas-base")
- >>> data = {
- ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
- ... "Age": ["56", "45", "59"],
- ... "Number of movies": ["87", "53", "69"],
- ... }
- >>> table = pd.DataFrame.from_dict(data)
- >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"]
- >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
- >>> outputs = model(**inputs)
- >>> last_hidden_states = outputs.last_hidden_state
- ```"""
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
- elif input_ids is not None:
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
- input_shape = input_ids.size()
- elif inputs_embeds is not None:
- input_shape = inputs_embeds.size()[:-1]
- else:
- raise ValueError("You have to specify either input_ids or inputs_embeds")
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- if attention_mask is None:
- attention_mask = torch.ones(input_shape, device=device)
- if token_type_ids is None:
- token_type_ids = torch.zeros(
- (*input_shape, len(self.config.type_vocab_sizes)), dtype=torch.long, device=device
- )
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
- # ourselves in which case we just need to make it broadcastable to all heads.
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
- # If a 2D ou 3D attention mask is provided for the cross-attention
- # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
- if self.config.is_decoder and encoder_hidden_states is not None:
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
- if encoder_attention_mask is None:
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
- else:
- encoder_extended_attention_mask = None
- # Prepare head mask if needed
- # 1.0 in head_mask indicate we keep the head
- # attention_probs has shape bsz x n_heads x N x N
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
- embedding_output = self.embeddings(
- input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
- )
- encoder_outputs = self.encoder(
- embedding_output,
- attention_mask=extended_attention_mask,
- head_mask=head_mask,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_extended_attention_mask,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- sequence_output = encoder_outputs[0]
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
- if not return_dict:
- return (sequence_output, pooled_output) + encoder_outputs[1:]
- return BaseModelOutputWithPooling(
- last_hidden_state=sequence_output,
- pooler_output=pooled_output,
- hidden_states=encoder_outputs.hidden_states,
- attentions=encoder_outputs.attentions,
- )
- @add_start_docstrings("""Tapas Model with a `language modeling` head on top.""", TAPAS_START_DOCSTRING)
- class TapasForMaskedLM(TapasPreTrainedModel):
- _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
- config_class = TapasConfig
- base_model_prefix = "tapas"
- def __init__(self, config):
- super().__init__(config)
- self.tapas = TapasModel(config, add_pooling_layer=False)
- self.cls = TapasOnlyMLMHead(config)
- # Initialize weights and apply final processing
- self.post_init()
- def get_output_embeddings(self):
- return self.cls.predictions.decoder
- def set_output_embeddings(self, new_embeddings):
- self.cls.predictions.decoder = new_embeddings
- self.cls.predictions.bias = new_embeddings.bias
- @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
- @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- token_type_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[Tuple, MaskedLMOutput]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
- config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
- loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
- Returns:
- Examples:
- ```python
- >>> from transformers import AutoTokenizer, TapasForMaskedLM
- >>> import pandas as pd
- >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base")
- >>> model = TapasForMaskedLM.from_pretrained("google/tapas-base")
- >>> data = {
- ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
- ... "Age": ["56", "45", "59"],
- ... "Number of movies": ["87", "53", "69"],
- ... }
- >>> table = pd.DataFrame.from_dict(data)
- >>> inputs = tokenizer(
- ... table=table, queries="How many [MASK] has George [MASK] played in?", return_tensors="pt"
- ... )
- >>> labels = tokenizer(
- ... table=table, queries="How many movies has George Clooney played in?", return_tensors="pt"
- ... )["input_ids"]
- >>> outputs = model(**inputs, labels=labels)
- >>> logits = outputs.logits
- ```"""
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.tapas(
- input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- encoder_hidden_states=encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- sequence_output = outputs[0]
- prediction_scores = self.cls(sequence_output)
- masked_lm_loss = None
- if labels is not None:
- loss_fct = CrossEntropyLoss() # -100 index = padding token
- masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
- if not return_dict:
- output = (prediction_scores,) + outputs[2:]
- return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
- return MaskedLMOutput(
- loss=masked_lm_loss,
- logits=prediction_scores,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
- @add_start_docstrings(
- """
- Tapas Model with a cell selection head and optional aggregation head on top for question-answering tasks on tables
- (linear layers on top of the hidden-states output to compute `logits` and optional `logits_aggregation`), e.g. for
- SQA, WTQ or WikiSQL-supervised tasks.
- """,
- TAPAS_START_DOCSTRING,
- )
- class TapasForQuestionAnswering(TapasPreTrainedModel):
- def __init__(self, config: TapasConfig):
- super().__init__(config)
- # base model
- self.tapas = TapasModel(config)
- # dropout (only used when training)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
- # cell selection heads
- if config.init_cell_selection_weights_to_zero:
- # init_cell_selection_weights_to_zero: Whether the initial weights should be
- # set to 0. This ensures that all tokens have the same prior probability.
- self.output_weights = nn.Parameter(torch.zeros(config.hidden_size))
- self.column_output_weights = nn.Parameter(torch.zeros(config.hidden_size))
- else:
- self.output_weights = nn.Parameter(torch.empty(config.hidden_size))
- nn.init.normal_(
- self.output_weights, std=config.initializer_range
- ) # here, a truncated normal is used in the original implementation
- self.column_output_weights = nn.Parameter(torch.empty(config.hidden_size))
- nn.init.normal_(
- self.column_output_weights, std=config.initializer_range
- ) # here, a truncated normal is used in the original implementation
- self.output_bias = nn.Parameter(torch.zeros([]))
- self.column_output_bias = nn.Parameter(torch.zeros([]))
- # aggregation head
- if config.num_aggregation_labels > 0:
- self.aggregation_classifier = nn.Linear(config.hidden_size, config.num_aggregation_labels)
- # Initialize weights and apply final processing
- self.post_init()
- @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
- @replace_return_docstrings(output_type=TableQuestionAnsweringOutput, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- token_type_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- table_mask: Optional[torch.LongTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- aggregation_labels: Optional[torch.LongTensor] = None,
- float_answer: Optional[torch.FloatTensor] = None,
- numeric_values: Optional[torch.FloatTensor] = None,
- numeric_values_scale: Optional[torch.FloatTensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, TableQuestionAnsweringOutput]:
- r"""
- table_mask (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
- Mask for the table. Indicates which tokens belong to the table (1). Question tokens, table headers and
- padding are 0.
- labels (`torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
- Labels per token for computing the hierarchical cell selection loss. This encodes the positions of the
- answer appearing in the table. Can be obtained using [`AutoTokenizer`].
- - 1 for tokens that are **part of the answer**,
- - 0 for tokens that are **not part of the answer**.
- aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
- Aggregation function index for every example in the batch for computing the aggregation loss. Indices
- should be in `[0, ..., config.num_aggregation_labels - 1]`. Only required in case of strong supervision for
- aggregation (WikiSQL-supervised).
- float_answer (`torch.FloatTensor` of shape `(batch_size, )`, *optional*):
- Float answer for every example in the batch. Set to *float('nan')* for cell selection questions. Only
- required in case of weak supervision (WTQ) to calculate the aggregate mask and regression loss.
- numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`, *optional*):
- Numeric values of every token, NaN for tokens which are not numeric values. Can be obtained using
- [`AutoTokenizer`]. Only required in case of weak supervision for aggregation (WTQ) to calculate the
- regression loss.
- numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`, *optional*):
- Scale of the numeric values of every token. Can be obtained using [`AutoTokenizer`]. Only required in case
- of weak supervision for aggregation (WTQ) to calculate the regression loss.
- Returns:
- Examples:
- ```python
- >>> from transformers import AutoTokenizer, TapasForQuestionAnswering
- >>> import pandas as pd
- >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-wtq")
- >>> model = TapasForQuestionAnswering.from_pretrained("google/tapas-base-finetuned-wtq")
- >>> data = {
- ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
- ... "Age": ["56", "45", "59"],
- ... "Number of movies": ["87", "53", "69"],
- ... }
- >>> table = pd.DataFrame.from_dict(data)
- >>> queries = ["How many movies has George Clooney played in?", "How old is Brad Pitt?"]
- >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
- >>> outputs = model(**inputs)
- >>> logits = outputs.logits
- >>> logits_aggregation = outputs.logits_aggregation
- ```"""
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.tapas(
- input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- sequence_output = outputs[0]
- pooled_output = outputs[1]
- sequence_output = self.dropout(sequence_output)
- if input_ids is not None:
- input_shape = input_ids.size()
- else:
- input_shape = inputs_embeds.size()[:-1]
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- # Construct indices for the table.
- if token_type_ids is None:
- token_type_ids = torch.zeros(
- (*input_shape, len(self.config.type_vocab_sizes)), dtype=torch.long, device=device
- )
- token_types = [
- "segment_ids",
- "column_ids",
- "row_ids",
- "prev_labels",
- "column_ranks",
- "inv_column_ranks",
- "numeric_relations",
- ]
- row_ids = token_type_ids[:, :, token_types.index("row_ids")]
- column_ids = token_type_ids[:, :, token_types.index("column_ids")]
- row_index = IndexMap(
- indices=torch.min(row_ids, torch.as_tensor(self.config.max_num_rows - 1, device=row_ids.device)),
- num_segments=self.config.max_num_rows,
- batch_dims=1,
- )
- col_index = IndexMap(
- indices=torch.min(column_ids, torch.as_tensor(self.config.max_num_columns - 1, device=column_ids.device)),
- num_segments=self.config.max_num_columns,
- batch_dims=1,
- )
- cell_index = ProductIndexMap(row_index, col_index)
- # Masks.
- input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- if attention_mask is None:
- attention_mask = torch.ones(input_shape, device=device)
- # Table cells only, without question tokens and table headers.
- if table_mask is None:
- table_mask = torch.where(row_ids > 0, torch.ones_like(row_ids), torch.zeros_like(row_ids))
- # torch.FloatTensor[batch_size, seq_length]
- input_mask_float = attention_mask.float().to(device)
- table_mask_float = table_mask.float().to(device)
- # Mask for cells that exist in the table (i.e. that are not padding).
- cell_mask, _ = reduce_mean(input_mask_float, cell_index)
- # Compute logits per token. These are used to select individual cells.
- logits = compute_token_logits(sequence_output, self.config.temperature, self.output_weights, self.output_bias)
- # Compute logits per column. These are used to select a column.
- column_logits = None
- if self.config.select_one_column:
- column_logits = compute_column_logits(
- sequence_output,
- self.column_output_weights,
- self.column_output_bias,
- cell_index,
- cell_mask,
- self.config.allow_empty_column_selection,
- )
- # Aggregation logits
- logits_aggregation = None
- if self.config.num_aggregation_labels > 0:
- logits_aggregation = self.aggregation_classifier(pooled_output)
- # Total loss calculation
- total_loss = 0.0
- calculate_loss = False
- if labels is not None:
- calculate_loss = True
- is_supervised = not self.config.num_aggregation_labels > 0 or not self.config.use_answer_as_supervision
- # Semi-supervised cell selection in case of no aggregation:
- # If the answer (the denotation) appears directly in the table we might
- # select the answer without applying any aggregation function. There are
- # some ambiguous cases, see utils._calculate_aggregate_mask for more info.
- # `aggregate_mask` is 1 for examples where we chose to aggregate and 0
- # for examples where we chose to select the answer directly.
- # `labels` encodes the positions of the answer appearing in the table.
- if is_supervised:
- aggregate_mask = None
- else:
- if float_answer is not None:
- assert (
- labels.shape[0] == float_answer.shape[0]
- ), "Make sure the answers are a FloatTensor of shape (batch_size,)"
- # <float32>[batch_size]
- aggregate_mask = _calculate_aggregate_mask(
- float_answer,
- pooled_output,
- self.config.cell_selection_preference,
- labels,
- self.aggregation_classifier,
- )
- else:
- raise ValueError("You have to specify float answers in order to calculate the aggregate mask")
- # Cell selection log-likelihood
- if self.config.average_logits_per_cell:
- logits_per_cell, _ = reduce_mean(logits, cell_index)
- logits = gather(logits_per_cell, cell_index)
- dist_per_token = torch.distributions.Bernoulli(logits=logits)
- # Compute cell selection loss per example.
- selection_loss_per_example = None
- if not self.config.select_one_column:
- weight = torch.where(
- labels == 0,
- torch.ones_like(labels, dtype=torch.float32),
- self.config.positive_label_weight * torch.ones_like(labels, dtype=torch.float32),
- )
- selection_loss_per_token = -dist_per_token.log_prob(labels) * weight
- selection_loss_per_example = torch.sum(selection_loss_per_token * input_mask_float, dim=1) / (
- torch.sum(input_mask_float, dim=1) + EPSILON_ZERO_DIVISION
- )
- else:
- selection_loss_per_example, logits = _single_column_cell_selection_loss(
- logits, column_logits, labels, cell_index, col_index, cell_mask
- )
- dist_per_token = torch.distributions.Bernoulli(logits=logits)
- # Supervised cell selection
- if self.config.disable_per_token_loss:
- pass
- elif is_supervised:
- total_loss += torch.mean(selection_loss_per_example)
- else:
- # For the not supervised case, do not assign loss for cell selection
- total_loss += torch.mean(selection_loss_per_example * (1.0 - aggregate_mask))
- # Semi-supervised regression loss and supervised loss for aggregations
- if self.config.num_aggregation_labels > 0:
- if is_supervised:
- # Note that `aggregate_mask` is None if the setting is supervised.
- if aggregation_labels is not None:
- assert (
- labels.shape[0] == aggregation_labels.shape[0]
- ), "Make sure the aggregation labels are a LongTensor of shape (batch_size,)"
- per_example_additional_loss = _calculate_aggregation_loss(
- logits_aggregation,
- aggregate_mask,
- aggregation_labels,
- self.config.use_answer_as_supervision,
- self.config.num_aggregation_labels,
- self.config.aggregation_loss_weight,
- )
- else:
- raise ValueError(
- "You have to specify aggregation labels in order to calculate the aggregation loss"
- )
- else:
- # Set aggregation labels to zeros
- aggregation_labels = torch.zeros(labels.shape[0], dtype=torch.long, device=labels.device)
- per_example_additional_loss = _calculate_aggregation_loss(
- logits_aggregation,
- aggregate_mask,
- aggregation_labels,
- self.config.use_answer_as_supervision,
- self.config.num_aggregation_labels,
- self.config.aggregation_loss_weight,
- )
- if self.config.use_answer_as_supervision:
- if numeric_values is not None and numeric_values_scale is not None:
- assert numeric_values.shape == numeric_values_scale.shape
- # Add regression loss for numeric answers which require aggregation.
- answer_loss, large_answer_loss_mask = _calculate_regression_loss(
- float_answer,
- aggregate_mask,
- dist_per_token,
- numeric_values,
- numeric_values_scale,
- table_mask_float,
- logits_aggregation,
- self.config,
- )
- per_example_additional_loss += answer_loss
- # Zero loss for examples with answer_loss > cutoff.
- per_example_additional_loss *= large_answer_loss_mask
- else:
- raise ValueError(
- "You have to specify numeric values and numeric values scale in order to calculate the"
- " regression loss"
- )
- total_loss += torch.mean(per_example_additional_loss)
- else:
- # if no label ids are provided, set them to zeros in order to properly compute logits
- labels = torch.zeros_like(logits)
- _, logits = _single_column_cell_selection_loss(
- logits, column_logits, labels, cell_index, col_index, cell_mask
- )
- if not return_dict:
- output = (logits, logits_aggregation) + outputs[2:]
- return ((total_loss,) + output) if calculate_loss else output
- return TableQuestionAnsweringOutput(
- loss=total_loss if calculate_loss else None,
- logits=logits,
- logits_aggregation=logits_aggregation,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
- @add_start_docstrings(
- """
- Tapas Model with a sequence classification head on top (a linear layer on top of the pooled output), e.g. for table
- entailment tasks, such as TabFact (Chen et al., 2020).
- """,
- TAPAS_START_DOCSTRING,
- )
- class TapasForSequenceClassification(TapasPreTrainedModel):
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.tapas = TapasModel(config)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
- # Initialize weights and apply final processing
- self.post_init()
- @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
- @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- token_type_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy). Note: this is called
- "classification_class_index" in the original implementation.
- Returns:
- Examples:
- ```python
- >>> from transformers import AutoTokenizer, TapasForSequenceClassification
- >>> import torch
- >>> import pandas as pd
- >>> tokenizer = AutoTokenizer.from_pretrained("google/tapas-base-finetuned-tabfact")
- >>> model = TapasForSequenceClassification.from_pretrained("google/tapas-base-finetuned-tabfact")
- >>> data = {
- ... "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
- ... "Age": ["56", "45", "59"],
- ... "Number of movies": ["87", "53", "69"],
- ... }
- >>> table = pd.DataFrame.from_dict(data)
- >>> queries = [
- ... "There is only one actor who is 45 years old",
- ... "There are 3 actors which played in more than 60 movies",
- ... ]
- >>> inputs = tokenizer(table=table, queries=queries, padding="max_length", return_tensors="pt")
- >>> labels = torch.tensor([1, 0]) # 1 means entailed, 0 means refuted
- >>> outputs = model(**inputs, labels=labels)
- >>> loss = outputs.loss
- >>> logits = outputs.logits
- ```"""
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.tapas(
- input_ids,
- attention_mask=attention_mask,
- token_type_ids=token_type_ids,
- position_ids=position_ids,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- pooled_output = outputs[1]
- pooled_output = self.dropout(pooled_output)
- logits = self.classifier(pooled_output)
- loss = None
- if labels is not None:
- if self.config.problem_type is None:
- if self.num_labels == 1:
- self.config.problem_type = "regression"
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
- self.config.problem_type = "single_label_classification"
- else:
- self.config.problem_type = "multi_label_classification"
- if self.config.problem_type == "regression":
- loss_fct = MSELoss()
- if self.num_labels == 1:
- loss = loss_fct(logits.squeeze(), labels.squeeze())
- else:
- loss = loss_fct(logits, labels)
- elif self.config.problem_type == "single_label_classification":
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
- elif self.config.problem_type == "multi_label_classification":
- loss_fct = BCEWithLogitsLoss()
- loss = loss_fct(logits, labels)
- if not return_dict:
- output = (logits,) + outputs[2:]
- return ((loss,) + output) if loss is not None else output
- return SequenceClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
- """ TAPAS utilities."""
- class AverageApproximationFunction(str, enum.Enum):
- RATIO = "ratio"
- FIRST_ORDER = "first_order"
- SECOND_ORDER = "second_order"
- # Beginning of everything related to segmented tensors
- class IndexMap:
- """Index grouping entries within a tensor."""
- def __init__(self, indices, num_segments, batch_dims=0):
- """
- Creates an index
- Args:
- indices (`torch.LongTensor`, same shape as a *values* Tensor to which the indices refer):
- Tensor containing the indices.
- num_segments (`torch.LongTensor`):
- Scalar tensor, the number of segments. All elements in a batched segmented tensor must have the same
- number of segments (although many segments can be empty).
- batch_dims (`int`, *optional*, defaults to 0):
- The number of batch dimensions. The first *batch_dims* dimensions of a SegmentedTensor are treated as
- batch dimensions. Segments in different batch elements are always distinct even if they have the same
- index.
- """
- self.indices = torch.as_tensor(indices)
- self.num_segments = torch.as_tensor(num_segments, device=indices.device)
- self.batch_dims = batch_dims
- def batch_shape(self):
- return self.indices.size()[: self.batch_dims] # returns a torch.Size object
- class ProductIndexMap(IndexMap):
- """The product of two indices."""
- def __init__(self, outer_index, inner_index):
- """
- Combines indices i and j into pairs (i, j). The result is an index where each segment (i, j) is the
- intersection of segments i and j. For example if the inputs represent table cells indexed by respectively rows
- and columns the output will be a table indexed by (row, column) pairs, i.e. by cell. The implementation
- combines indices {0, .., n - 1} and {0, .., m - 1} into {0, .., nm - 1}. The output has *num_segments* equal to
- *outer_index.num_segments* * *inner_index.num_segments*
- Args:
- outer_index (`IndexMap`):
- IndexMap.
- inner_index (`IndexMap`):
- IndexMap, must have the same shape as *outer_index*.
- """
- if outer_index.batch_dims != inner_index.batch_dims:
- raise ValueError("outer_index.batch_dims and inner_index.batch_dims must be the same.")
- super().__init__(
- indices=(inner_index.indices + outer_index.indices * inner_index.num_segments),
- num_segments=inner_index.num_segments * outer_index.num_segments,
- batch_dims=inner_index.batch_dims,
- )
- self.outer_index = outer_index
- self.inner_index = inner_index
- def project_outer(self, index):
- """Projects an index with the same index set onto the outer components."""
- indices = torch.div(index.indices, self.inner_index.num_segments, rounding_mode="floor").type(torch.long)
- return IndexMap(indices=indices, num_segments=self.outer_index.num_segments, batch_dims=index.batch_dims)
- def project_inner(self, index):
- """Projects an index with the same index set onto the inner components."""
- return IndexMap(
- indices=torch.fmod(index.indices, self.inner_index.num_segments)
- .type(torch.float)
- .floor()
- .type(torch.long),
- num_segments=self.inner_index.num_segments,
- batch_dims=index.batch_dims,
- )
- def gather(values, index, name="segmented_gather"):
- """
- Gathers from *values* using the index map. For each element in the domain of the index map this operation looks up
- a value for that index in *values*. Two elements from the same segment always get assigned the same value.
- Args:
- values (`torch.Tensor` of shape (B1, ..., Bn, num_segments, V1, ...)):
- Tensor with segment values.
- index (`IndexMap` of shape (B1, ..., Bn, I1, ..., Ik)):
- IndexMap.
- name (`str`, *optional*, defaults to 'segmented_gather'):
- Name for the operation. Currently not used
- Returns:
- `tuple(torch.Tensor)`: Tensor of shape (B1, ..., Bn, I1, ..., Ik, V1, ...) with the gathered values.
- """
- indices = index.indices
- # first, check whether the indices of the index represent scalar values (i.e. not vectorized)
- if len(values.shape[index.batch_dims :]) < 2:
- return torch.gather(
- values,
- index.batch_dims,
- indices.view(
- values.size()[0], -1
- ), # torch.gather expects index to have the same number of dimensions as values
- ).view(indices.size())
- else:
- # this means we have a vectorized version
- # we have to adjust the index
- indices = indices.unsqueeze(-1).expand(values.shape)
- return torch.gather(values, index.batch_dims, indices)
- def flatten(index, name="segmented_flatten"):
- """
- Flattens a batched index map (which is typically of shape batch_size, seq_length) to a 1d index map. This operation
- relabels the segments to keep batch elements distinct. The k-th batch element will have indices shifted by
- *num_segments* * (k - 1). The result is a tensor with *num_segments* multiplied by the number of elements in the
- batch.
- Args:
- index (`IndexMap`):
- IndexMap to flatten.
- name (`str`, *optional*, defaults to 'segmented_flatten'):
- Name for the operation. Currently not used
- Returns:
- (`IndexMap`): The flattened IndexMap.
- """
- # first, get batch_size as scalar tensor
- batch_size = torch.prod(torch.tensor(list(index.batch_shape())))
- # next, create offset as 1-D tensor of length batch_size,
- # and multiply element-wise by num segments (to offset different elements in the batch) e.g. if batch size is 2: [0, 64]
- offset = torch.arange(start=0, end=batch_size, device=index.num_segments.device) * index.num_segments
- offset = offset.view(index.batch_shape())
- for _ in range(index.batch_dims, len(index.indices.size())): # typically range(1,2)
- offset = offset.unsqueeze(-1)
- indices = offset + index.indices
- return IndexMap(indices=indices.view(-1), num_segments=index.num_segments * batch_size, batch_dims=0)
- def range_index_map(batch_shape, num_segments, name="range_index_map"):
- """
- Constructs an index map equal to range(num_segments).
- Args:
- batch_shape (`torch.Size`):
- Batch shape
- num_segments (`int`):
- Number of segments
- name (`str`, *optional*, defaults to 'range_index_map'):
- Name for the operation. Currently not used
- Returns:
- (`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments).
- """
- batch_shape = torch.as_tensor(
- batch_shape, dtype=torch.long
- ) # create a rank 1 tensor vector containing batch_shape (e.g. [2])
- assert len(batch_shape.size()) == 1
- num_segments = torch.as_tensor(num_segments) # create a rank 0 tensor (scalar) containing num_segments (e.g. 64)
- assert len(num_segments.size()) == 0
- indices = torch.arange(
- start=0, end=num_segments, device=num_segments.device
- ) # create a rank 1 vector with num_segments elements
- new_tensor = torch.cat(
- [torch.ones_like(batch_shape, dtype=torch.long, device=num_segments.device), num_segments.unsqueeze(dim=0)],
- dim=0,
- )
- # new_tensor is just a vector of [1 64] for example (assuming only 1 batch dimension)
- new_shape = [int(x) for x in new_tensor.tolist()]
- indices = indices.view(new_shape)
- multiples = torch.cat([batch_shape, torch.as_tensor([1])], dim=0)
- indices = indices.repeat(multiples.tolist())
- # equivalent (in Numpy:)
- # indices = torch.as_tensor(np.tile(indices.numpy(), multiples.tolist()))
- return IndexMap(indices=indices, num_segments=num_segments, batch_dims=list(batch_shape.size())[0])
- def _segment_reduce(values, index, segment_reduce_fn, name):
- """
- Applies a segment reduction segment-wise.
- Args:
- values (`torch.Tensor`):
- Tensor with segment values.
- index (`IndexMap`):
- IndexMap.
- segment_reduce_fn (`str`):
- Name for the reduce operation. One of "sum", "mean", "max" or "min".
- name (`str`):
- Name for the operation. Currently not used
- Returns:
- (`IndexMap`): IndexMap of shape batch_shape with elements equal to range(num_segments).
- """
- # Flatten the batch dimensions, as segments ops (scatter) do not support batching.
- # However if `values` has extra dimensions to the right keep them
- # unflattened. Segmented ops support vector-valued operations.
- flat_index = flatten(index)
- vector_shape = values.size()[len(index.indices.size()) :] # torch.Size object
- flattened_shape = torch.cat(
- [torch.as_tensor([-1], dtype=torch.long), torch.as_tensor(vector_shape, dtype=torch.long)], dim=0
- )
- # changed "view" by "reshape" in the following line
- flat_values = values.reshape(flattened_shape.tolist())
- out = torch.zeros(int(flat_index.num_segments), dtype=torch.float, device=flat_values.device)
- segment_means = out.scatter_reduce(
- dim=0, index=flat_index.indices.long(), src=flat_values.float(), reduce=segment_reduce_fn, include_self=False
- )
- # Unflatten the values.
- new_shape = torch.cat(
- [
- torch.as_tensor(index.batch_shape(), dtype=torch.long),
- torch.as_tensor([index.num_segments], dtype=torch.long),
- torch.as_tensor(vector_shape, dtype=torch.long),
- ],
- dim=0,
- )
- output_values = segment_means.clone().view(new_shape.tolist()).to(values.dtype)
- output_index = range_index_map(index.batch_shape(), index.num_segments)
- return output_values, output_index
- def reduce_sum(values, index, name="segmented_reduce_sum"):
- """
- Sums a tensor over its segments.
- Outputs 0 for empty segments.
- This operations computes the sum over segments, with support for:
- - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
- - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be a sum of
- vectors rather than scalars. Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
- Args:
- values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
- Tensor containing the values of which the sum must be taken segment-wise.
- index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
- Index defining the segments.
- name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
- Name for the operation. Currently not used
- Returns:
- output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
- output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments]. .
- """
- return _segment_reduce(values, index, "sum", name)
- def reduce_mean(values, index, name="segmented_reduce_mean"):
- """
- Averages a tensor over its segments.
- Outputs 0 for empty segments.
- This operations computes the mean over segments, with support for:
- - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
- - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be a mean of
- vectors rather than scalars.
- Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
- Args:
- values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
- Tensor containing the values of which the mean must be taken segment-wise.
- index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
- Index defining the segments.
- name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
- Name for the operation. Currently not used
- Returns:
- output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
- output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
- """
- return _segment_reduce(values, index, "mean", name)
- def reduce_max(values, index, name="segmented_reduce_max"):
- """
- Computes the maximum over segments.
- This operation computes the maximum over segments, with support for:
- - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
- - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be an element-wise
- maximum of vectors rather than scalars.
- Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
- Args:
- values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
- Tensor containing the values of which the max must be taken segment-wise.
- index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
- Index defining the segments.
- name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
- Name for the operation. Currently not used
- Returns:
- output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
- output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
- """
- return _segment_reduce(values, index, "amax", name)
- def reduce_min(values, index, name="segmented_reduce_min"):
- """
- Computes the minimum over segments.
- This operations computes the minimum over segments, with support for:
- - Batching using the first dimensions [B1, B2, ..., Bn]. Each element in a batch can have different indices.
- - Vectorization using the last dimension [V1, V2, ...]. If they are present, the output will be an element-wise
- minimum of vectors rather than scalars.
- Only the middle dimensions [I1, ..., Ik] are reduced by the operation.
- Args:
- values (`torch.Tensor` of shape [B1, B2, ..., Bn, I1, .., Ik, V1, V2, ..]):
- Tensor containing the values of which the min must be taken segment-wise.
- index (`IndexMap`, indices are of shape [B1, B2, ..., Bn, I1, .., Ik].):
- Index defining the segments.
- name (`str`, *optional*, defaults to 'segmented_reduce_sum'):
- Name for the operation. Currently not used
- Returns:
- output_values (`torch.Tensor`of shape [B1, B2, ..., Bn, num_segments, V1, V2, ..]): Tensor containing the
- output values. output_index (`IndexMap`): IndexMap with shape [B1, B2, ..., Bn, num_segments].
- """
- return _segment_reduce(values, index, "amin", name)
- # End of everything related to segmented tensors
- def compute_column_logits(
- sequence_output, column_output_weights, column_output_bias, cell_index, cell_mask, allow_empty_column_selection
- ):
- """
- Computes the column logits.
- Args:
- sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the model.
- column_output_weights (`torch.FloatTensor` of shape `(hidden_size)`):
- Weights of the linear layer for column selection.
- column_output_bias (`torch.FloatTensor` of shape `()`):
- Bias of the linear layer for column selection.
- cell_index (`ProductIndexMap`):
- Index that groups tokens into cells.
- cell_mask (`torch.FloatTensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
- Mask for cells that exist in the table (i.e. that are not padding).
- allow_empty_column_selection (`bool`):
- Whether to allow not to select any column
- Returns:
- column_logits (`torch.FloatTensor`of shape `(batch_size, max_num_cols)`): Tensor containing the column logits
- for every example in the batch.
- """
- # First, compute the token logits (batch_size, seq_len) - without temperature
- token_logits = torch.einsum("bsj,j->bs", sequence_output, column_output_weights) + column_output_bias
- # Next, average the logits per cell (batch_size, max_num_cols*max_num_rows)
- cell_logits, cell_logits_index = reduce_mean(token_logits, cell_index)
- # Finally, average the logits per column (batch_size, max_num_cols)
- column_index = cell_index.project_inner(cell_logits_index)
- column_logits, out_index = reduce_sum(cell_logits * cell_mask, column_index)
- cell_count, _ = reduce_sum(cell_mask, column_index)
- column_logits /= cell_count + EPSILON_ZERO_DIVISION
- # Mask columns that do not appear in the example.
- is_padding = torch.logical_and(cell_count < 0.5, ~torch.eq(out_index.indices, 0))
- column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * torch.as_tensor(
- is_padding, dtype=torch.float32, device=is_padding.device
- )
- if not allow_empty_column_selection:
- column_logits += CLOSE_ENOUGH_TO_LOG_ZERO * torch.as_tensor(
- torch.eq(out_index.indices, 0), dtype=torch.float32, device=out_index.indices.device
- )
- return column_logits
- def _single_column_cell_selection_loss(token_logits, column_logits, labels, cell_index, col_index, cell_mask):
- """
- Computes the loss for cell selection constrained to a single column. The loss is a hierarchical log-likelihood. The
- model first predicts a column and then selects cells within that column (conditioned on the column). Cells outside
- the selected column are never selected.
- Args:
- token_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
- Tensor containing the logits per token.
- column_logits (`torch.FloatTensor` of shape `(batch_size, max_num_cols)`):
- Tensor containing the logits per column.
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Labels per token.
- cell_index (`ProductIndexMap`):
- Index that groups tokens into cells.
- col_index (`IndexMap`):
- Index that groups tokens into columns.
- cell_mask (`torch.FloatTensor` of shape `(batch_size, max_num_rows * max_num_cols)`):
- Mask for cells that exist in the table (i.e. that are not padding).
- Returns:
- selection_loss_per_example (`torch.FloatTensor` of shape `(batch_size,)`): Loss for each example. logits
- (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): New logits which are only allowed to select
- cells in a single column. Logits outside of the most likely column according to *column_logits* will be set to
- a very low value (such that the probabilities are 0).
- """
- # Part 1: column loss
- # First find the column we should select. We use the column with maximum number of selected cells.
- labels_per_column, _ = reduce_sum(torch.as_tensor(labels, dtype=torch.float32, device=labels.device), col_index)
- # shape of labels_per_column is (batch_size, max_num_cols). It contains the number of label ids for every column, for every example
- column_label = torch.argmax(labels_per_column, dim=-1) # shape (batch_size,)
- # Check if there are no selected cells in the column. In that case the model
- # should predict the special column id 0, which means "select nothing".
- no_cell_selected = torch.eq(
- torch.max(labels_per_column, dim=-1)[0], 0
- ) # no_cell_selected is of shape (batch_size,) and equals True
- # if an example of the batch has no cells selected (i.e. if there are no labels set to 1 for that example)
- column_label = torch.where(
- no_cell_selected.view(column_label.size()), torch.zeros_like(column_label), column_label
- )
- column_dist = torch.distributions.Categorical(logits=column_logits) # shape (batch_size, max_num_cols)
- column_loss_per_example = -column_dist.log_prob(column_label)
- # Part 2: cell loss
- # Reduce the labels and logits to per-cell from per-token.
- # logits_per_cell: shape (batch_size, max_num_rows*max_num_cols) i.e. (batch_size, 64*32)
- logits_per_cell, _ = reduce_mean(token_logits, cell_index)
- # labels_per_cell: shape (batch_size, 64*32), indicating whether each cell should be selected (1) or not (0)
- labels_per_cell, labels_index = reduce_max(
- torch.as_tensor(labels, dtype=torch.long, device=labels.device), cell_index
- )
- # Mask for the selected column.
- # column_id_for_cells: shape (batch_size, 64*32), indicating to which column each cell belongs
- column_id_for_cells = cell_index.project_inner(labels_index).indices
- # column_mask: shape (batch_size, 64*32), equal to 1 if cell belongs to column to be selected
- column_mask = torch.as_tensor(
- torch.eq(column_id_for_cells, torch.unsqueeze(column_label, dim=-1)),
- dtype=torch.float32,
- device=cell_mask.device,
- )
- # Compute the log-likelihood for cells, but only for the selected column.
- cell_dist = torch.distributions.Bernoulli(logits=logits_per_cell) # shape (batch_size, 64*32)
- cell_log_prob = cell_dist.log_prob(labels_per_cell.type(torch.float32)) # shape(batch_size, 64*32)
- cell_loss = -torch.sum(cell_log_prob * column_mask * cell_mask, dim=1)
- # We need to normalize the loss by the number of cells in the column.
- cell_loss /= torch.sum(column_mask * cell_mask, dim=1) + EPSILON_ZERO_DIVISION
- selection_loss_per_example = column_loss_per_example
- selection_loss_per_example += torch.where(
- no_cell_selected.view(selection_loss_per_example.size()),
- torch.zeros_like(selection_loss_per_example),
- cell_loss,
- )
- # Set the probs outside the selected column (selected by the *model*)
- # to 0. This ensures backwards compatibility with models that select
- # cells from multiple columns.
- selected_column_id = torch.as_tensor(
- torch.argmax(column_logits, dim=-1), dtype=torch.long, device=column_logits.device
- ) # shape (batch_size,)
- # selected_column_mask: shape (batch_size, 64*32), equal to 1 if cell belongs to column selected by the model
- selected_column_mask = torch.as_tensor(
- torch.eq(column_id_for_cells, torch.unsqueeze(selected_column_id, dim=-1)),
- dtype=torch.float32,
- device=selected_column_id.device,
- )
- # Never select cells with the special column id 0.
- selected_column_mask = torch.where(
- torch.eq(column_id_for_cells, 0).view(selected_column_mask.size()),
- torch.zeros_like(selected_column_mask),
- selected_column_mask,
- )
- new_logits_per_cell = logits_per_cell + CLOSE_ENOUGH_TO_LOG_ZERO * (1.0 - cell_mask * selected_column_mask)
- logits = gather(new_logits_per_cell, cell_index)
- return selection_loss_per_example, logits
- def compute_token_logits(sequence_output, temperature, output_weights, output_bias):
- """
- Computes logits per token
- Args:
- sequence_output (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Also known as last_hidden_state. Sequence of hidden-states at the output of the last layer of the model.
- temperature (`float`):
- Temperature for the Bernoulli distribution.
- output_weights (`torch.FloatTensor` of shape `(hidden_size,)`):
- Weights of the linear layer for cell selection.
- output_bias (`torch.FloatTensor` of shape `()`):
- Bias of the linear layer for cell selection
- Returns:
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): Logits per token.
- """
- logits = (torch.einsum("bsj,j->bs", sequence_output, output_weights) + output_bias) / temperature
- return logits
- def _calculate_aggregate_mask(answer, pooled_output, cell_selection_preference, labels, aggregation_classifier):
- """
- Finds examples where the model should select cells with no aggregation.
- Returns a mask that determines for which examples should the model select answers directly from the table, without
- any aggregation function. If the answer is a piece of text the case is unambiguous as aggregation functions only
- apply to numbers. If the answer is a number but does not appear in the table then we must use some aggregation
- case. The ambiguous case is when the answer is a number that also appears in the table. In this case we use the
- aggregation function probabilities predicted by the model to decide whether to select or aggregate. The threshold
- for this is a hyperparameter *cell_selection_preference*
- Args:
- answer (`torch.FloatTensor` of shape `(batch_size, )`):
- Answer for every example in the batch. Nan if there is no scalar answer.
- pooled_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
- Output of the pooler (BertPooler) on top of the encoder layer.
- cell_selection_preference (`float`):
- Preference for cell selection in ambiguous cases.
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Labels per token. aggregation_classifier (`torch.nn.Linear`): Aggregation head
- Returns:
- aggregate_mask (`torch.FloatTensor` of shape `(batch_size,)`): A mask set to 1 for examples that should use
- aggregation functions.
- """
- # torch.FloatTensor(batch_size,)
- aggregate_mask_init = torch.logical_not(torch.isnan(answer)).type(torch.FloatTensor).to(answer.device)
- logits_aggregation = aggregation_classifier(pooled_output)
- dist_aggregation = torch.distributions.categorical.Categorical(logits=logits_aggregation)
- # Index 0 corresponds to "no aggregation".
- aggregation_ops_total_mass = torch.sum(dist_aggregation.probs[:, 1:], dim=1)
- # Cell selection examples according to current model.
- is_pred_cell_selection = aggregation_ops_total_mass <= cell_selection_preference
- # Examples with non-empty cell selection supervision.
- is_cell_supervision_available = torch.sum(labels, dim=1) > 0
- # torch.where is not equivalent to tf.where (in tensorflow 1)
- # hence the added .view on the condition to match the shape of the first tensor
- aggregate_mask = torch.where(
- torch.logical_and(is_pred_cell_selection, is_cell_supervision_available).view(aggregate_mask_init.size()),
- torch.zeros_like(aggregate_mask_init, dtype=torch.float32),
- aggregate_mask_init,
- )
- aggregate_mask = aggregate_mask.detach()
- return aggregate_mask
- def _calculate_aggregation_loss_known(
- logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels
- ):
- """
- Calculates aggregation loss when its type is known during training.
- In the weakly supervised setting, the only known information is that for cell selection examples, "no aggregation"
- should be predicted. For other examples (those that require aggregation), no loss is accumulated. In the setting
- where aggregation type is always known, standard cross entropy loss is accumulated for all examples
- Args:
- logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
- Logits per aggregation operation.
- aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
- A mask set to 1 for examples that should use aggregation functions.
- aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`):
- Aggregation function id for every example in the batch.
- use_answer_as_supervision (`bool`, *optional*):
- Whether to use the answer as the only supervision for aggregation examples.
- num_aggregation_labels (`int`, *optional*, defaults to 0):
- The number of aggregation operators to predict.
- Returns:
- aggregation_loss_known (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss (when its type is known
- during training) per example.
- """
- if use_answer_as_supervision:
- # Prepare "no aggregation" targets for cell selection examples.
- target_aggregation = torch.zeros_like(aggregate_mask, dtype=torch.long)
- else:
- # Use aggregation supervision as the target.
- target_aggregation = aggregation_labels
- one_hot_labels = nn.functional.one_hot(target_aggregation, num_classes=num_aggregation_labels).type(torch.float32)
- log_probs = nn.functional.log_softmax(logits_aggregation, dim=-1)
- # torch.FloatTensor[batch_size]
- per_example_aggregation_intermediate = -torch.sum(one_hot_labels * log_probs, dim=-1)
- if use_answer_as_supervision:
- # Accumulate loss only for examples requiring cell selection
- # (no aggregation).
- return per_example_aggregation_intermediate * (1 - aggregate_mask)
- else:
- return per_example_aggregation_intermediate
- def _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask):
- """
- Calculates aggregation loss in the case of answer supervision.
- Args:
- logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
- Logits per aggregation operation.
- aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
- A mask set to 1 for examples that should use aggregation functions
- Returns:
- aggregation_loss_unknown (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss (in case of answer
- supervision) per example.
- """
- dist_aggregation = torch.distributions.categorical.Categorical(logits=logits_aggregation)
- # Index 0 corresponds to "no aggregation".
- aggregation_ops_total_mass = torch.sum(dist_aggregation.probs[:, 1:], dim=1)
- # Predict some aggregation in case of an answer that needs aggregation.
- # This increases the probability of all aggregation functions, in a way
- # similar to MML, but without considering whether the function gives the
- # correct answer.
- return -torch.log(aggregation_ops_total_mass) * aggregate_mask
- def _calculate_aggregation_loss(
- logits_aggregation,
- aggregate_mask,
- aggregation_labels,
- use_answer_as_supervision,
- num_aggregation_labels,
- aggregation_loss_weight,
- ):
- """
- Calculates the aggregation loss per example.
- Args:
- logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
- Logits per aggregation operation.
- aggregate_mask (`torch.FloatTensor` of shape `(batch_size, )`):
- A mask set to 1 for examples that should use aggregation functions.
- aggregation_labels (`torch.LongTensor` of shape `(batch_size, )`):
- Aggregation function id for every example in the batch.
- use_answer_as_supervision (`bool`, *optional*):
- Whether to use the answer as the only supervision for aggregation examples.
- num_aggregation_labels (`int`, *optional*, defaults to 0):
- The number of aggregation operators to predict.
- aggregation_loss_weight (`float`, *optional*, defaults to 1.0):
- Importance weight for the aggregation loss.
- Returns:
- aggregation_loss (`torch.FloatTensor` of shape `(batch_size,)`): Aggregation loss per example.
- """
- per_example_aggregation_loss = _calculate_aggregation_loss_known(
- logits_aggregation, aggregate_mask, aggregation_labels, use_answer_as_supervision, num_aggregation_labels
- )
- if use_answer_as_supervision:
- # Add aggregation loss for numeric answers that need aggregation.
- per_example_aggregation_loss += _calculate_aggregation_loss_unknown(logits_aggregation, aggregate_mask)
- return aggregation_loss_weight * per_example_aggregation_loss
- def _calculate_expected_result(
- dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config
- ):
- """
- Calculates the expected result given cell and aggregation probabilities.
- Args:
- dist_per_cell (`torch.distributions.Bernoulli`):
- Cell selection distribution for each cell.
- numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
- Numeric values of every token. Nan for tokens which are not numeric values.
- numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
- Scale of the numeric values of every token.
- input_mask_float (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
- Mask for the table, without question tokens and table headers.
- logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
- Logits per aggregation operation.
- config ([`TapasConfig`]):
- Model configuration class with all the hyperparameters of the model
- Returns:
- expected_result (`torch.FloatTensor` of shape `(batch_size,)`): The expected result per example.
- """
- if config.use_gumbel_for_cells:
- gumbel_dist = torch.distributions.RelaxedBernoulli(
- # The token logits where already divided by the temperature and used for
- # computing cell selection errors so we need to multiply it again here
- temperature=config.temperature,
- logits=dist_per_cell.logits * config.temperature,
- )
- scaled_probability_per_cell = gumbel_dist.sample()
- else:
- scaled_probability_per_cell = dist_per_cell.probs
- # <float32>[batch_size, seq_length]
- scaled_probability_per_cell = (scaled_probability_per_cell / numeric_values_scale) * input_mask_float
- count_result = torch.sum(scaled_probability_per_cell, dim=1)
- numeric_values_masked = torch.where(
- torch.isnan(numeric_values), torch.zeros_like(numeric_values), numeric_values
- ) # Mask non-numeric table values to zero.
- sum_result = torch.sum(scaled_probability_per_cell * numeric_values_masked, dim=1)
- avg_approximation = config.average_approximation_function
- if avg_approximation == AverageApproximationFunction.RATIO:
- average_result = sum_result / (count_result + EPSILON_ZERO_DIVISION)
- elif avg_approximation == AverageApproximationFunction.FIRST_ORDER:
- # The sum of all probabilities except that correspond to other cells
- # Ex here stands for expectation, more explicitly the expectation of the sum of N-1 Bernoulli random variables plus
- # the constant 1, which is computed as adding all N expected values and subtracting the extra one. It corresponds to X_c
- # in Appendix D of the original TAPAS paper which is trying to approximate the average of a random set.
- ex = torch.sum(scaled_probability_per_cell, dim=1, keepdim=True) - scaled_probability_per_cell + 1
- average_result = torch.sum(numeric_values_masked * scaled_probability_per_cell / ex, dim=1)
- elif avg_approximation == AverageApproximationFunction.SECOND_ORDER:
- # The sum of all probabilities except that correspond to other cells
- ex = torch.sum(scaled_probability_per_cell, dim=1, keepdim=True) - scaled_probability_per_cell + 1
- pointwise_var = scaled_probability_per_cell * (1 - scaled_probability_per_cell)
- var = torch.sum(pointwise_var, dim=1, keepdim=True) - pointwise_var
- multiplier = (var / torch.square(ex) + 1) / ex
- average_result = torch.sum(numeric_values_masked * scaled_probability_per_cell * multiplier, dim=1)
- else:
- raise ValueError(f"Invalid average_approximation_function: {config.average_approximation_function}")
- if config.use_gumbel_for_aggregation:
- gumbel_dist = torch.distributions.RelaxedOneHotCategorical(
- config.aggregation_temperature, logits=logits_aggregation[:, 1:]
- )
- # <float32>[batch_size, num_aggregation_labels - 1]
- aggregation_op_only_probs = gumbel_dist.sample()
- else:
- # <float32>[batch_size, num_aggregation_labels - 1]
- aggregation_op_only_probs = nn.functional.softmax(
- logits_aggregation[:, 1:] / config.aggregation_temperature, dim=-1
- )
- all_results = torch.cat(
- [
- torch.unsqueeze(sum_result, dim=1),
- torch.unsqueeze(average_result, dim=1),
- torch.unsqueeze(count_result, dim=1),
- ],
- dim=1,
- )
- expected_result = torch.sum(all_results * aggregation_op_only_probs, dim=1)
- return expected_result
- # PyTorch does not currently support Huber loss with custom delta so we define it ourself
- def huber_loss(input, target, delta: float = 1.0):
- errors = torch.abs(input - target) # shape (batch_size,)
- return torch.where(errors < delta, 0.5 * errors**2, errors * delta - (0.5 * delta**2))
- def _calculate_regression_loss(
- answer,
- aggregate_mask,
- dist_per_cell,
- numeric_values,
- numeric_values_scale,
- input_mask_float,
- logits_aggregation,
- config,
- ):
- """
- Calculates the regression loss per example.
- Args:
- answer (`torch.FloatTensor` of shape `(batch_size,)`):
- Answer for every example in the batch. Nan if there is no scalar answer.
- aggregate_mask (`torch.FloatTensor` of shape `(batch_size,)`):
- A mask set to 1 for examples that should use aggregation functions.
- dist_per_cell (`torch.distributions.Bernoulli`):
- Cell selection distribution for each cell.
- numeric_values (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
- Numeric values of every token. Nan for tokens which are not numeric values.
- numeric_values_scale (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
- Scale of the numeric values of every token.
- input_mask_float (`torch.FloatTensor` of shape `(batch_size, seq_length)`):
- Mask for the table, without question tokens and table headers.
- logits_aggregation (`torch.FloatTensor` of shape `(batch_size, num_aggregation_labels)`):
- Logits per aggregation operation.
- config ([`TapasConfig`]):
- Model configuration class with all the parameters of the model
- Returns:
- per_example_answer_loss_scaled (`torch.FloatTensor` of shape `(batch_size,)`): Scales answer loss for each
- example in the batch. large_answer_loss_mask (`torch.FloatTensor` of shape `(batch_size,)`): A mask which is 1
- for examples for which their answer loss is larger than the answer_loss_cutoff.
- """
- # float32 (batch_size,)
- expected_result = _calculate_expected_result(
- dist_per_cell, numeric_values, numeric_values_scale, input_mask_float, logits_aggregation, config
- )
- # float32 (batch_size,)
- answer_masked = torch.where(torch.isnan(answer), torch.zeros_like(answer), answer)
- if config.use_normalized_answer_loss:
- normalizer = (torch.max(torch.abs(expected_result), torch.abs(answer_masked)) + EPSILON_ZERO_DIVISION).detach()
- normalized_answer_masked = answer_masked / normalizer
- normalized_expected_result = expected_result / normalizer
- per_example_answer_loss = huber_loss(
- normalized_expected_result * aggregate_mask, normalized_answer_masked * aggregate_mask
- )
- else:
- per_example_answer_loss = huber_loss(
- expected_result * aggregate_mask, answer_masked * aggregate_mask, delta=config.huber_loss_delta
- )
- if config.answer_loss_cutoff is None:
- large_answer_loss_mask = torch.ones_like(per_example_answer_loss, dtype=torch.float32)
- else:
- large_answer_loss_mask = torch.where(
- per_example_answer_loss > config.answer_loss_cutoff,
- torch.zeros_like(per_example_answer_loss, dtype=torch.float32),
- torch.ones_like(per_example_answer_loss, dtype=torch.float32),
- )
- per_example_answer_loss_scaled = config.answer_loss_importance * (per_example_answer_loss * aggregate_mask)
- return per_example_answer_loss_scaled, large_answer_loss_mask
|