| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894 |
- # coding=utf-8
- # Copyright 2023 The HuggingFace Inc. & Google team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Pix2Struct modeling file"""
- import math
- from typing import Dict, List, Optional, Tuple, Union
- import torch
- import torch.utils.checkpoint
- from torch import nn
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
- from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import AttentionMaskConverter
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPooling,
- CausalLMOutputWithCrossAttentions,
- Seq2SeqLMOutput,
- Seq2SeqModelOutput,
- )
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import ALL_LAYERNORM_LAYERS
- from ...utils import (
- DUMMY_INPUTS,
- DUMMY_MASK,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- is_torch_fx_proxy,
- is_torchdynamo_compiling,
- logging,
- replace_return_docstrings,
- )
- from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig
# Module-level logger; used below to report whether the apex fused norm was picked up.
logger = logging.get_logger(name=__name__)

# General docstring
# Config class name handed to `replace_return_docstrings(config_class=...)` on the forward methods.
_CONFIG_FOR_DOC = "Pix2StructConfig"
# Adapted from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pix2Struct
class Pix2StructLayerNorm(nn.Module):
    """
    T5-style layer norm: scale-only RMS normalization (no bias, no mean subtraction).

    See Root Mean Square Layer Normalization, https://arxiv.org/abs/1910.07467.

    Args:
        hidden_size: size of the last dimension; one learnable scale per feature.
        eps: value added to the mean square before the reciprocal square root.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # The mean of squares is always accumulated in fp32, even for half-precision inputs.
        mean_square = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = hidden_states * inverse_rms

        # Cast back down when the scale parameter itself is stored in half precision.
        half_precision_dtypes = (torch.float16, torch.bfloat16)
        if self.weight.dtype in half_precision_dtypes:
            normalized = normalized.to(self.weight.dtype)

        return self.weight * normalized
# Prefer apex's fused RMS norm when it is importable; otherwise keep the pure-PyTorch Pix2StructLayerNorm.
try:
    from apex.normalization import FusedRMSNorm
except ImportError:
    # apex is not installed: using the normal Pix2StructLayerNorm
    pass
except Exception:
    # apex is present but importing it failed for some other reason; degrade gracefully.
    logger.warning("Discovered apex but it failed to load, falling back to Pix2StructLayerNorm")
else:
    # Only the import itself is guarded above, so failures here are not misreported as apex load errors.
    Pix2StructLayerNorm = FusedRMSNorm  # noqa
    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pix2StructLayerNorm")

# Register whichever norm class was selected in the shared ALL_LAYERNORM_LAYERS list.
ALL_LAYERNORM_LAYERS.append(Pix2StructLayerNorm)
class Pix2StructVisionEmbeddings(nn.Module):
    r"""
    Construct the embeddings from patch. In `Pix2Struct` the input is different from classic Vision-transformer models.
    Here the input is a sequence of `seq_len` flattened patches that also combines padding patches (tokens). Each patch
    carries its row index, its column index and then its flattened pixel values; the pixels are linearly projected
    to `hidden_size` and a learned row embedding and column embedding are added on top.
    """

    def __init__(self, config: Pix2StructConfig) -> None:
        super().__init__()
        self.patch_projection = nn.Linear(config.patch_embed_hidden_size, config.hidden_size)

        self.row_embedder = nn.Embedding(config.seq_len, config.hidden_size)
        self.column_embedder = nn.Embedding(config.seq_len, config.hidden_size)

        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, flattened_patches: torch.Tensor) -> torch.Tensor:
        # Layout of `flattened_patches`: (batch_size, seq_len, 2 + patch_embed_hidden_size), where feature 0
        # is the row index, feature 1 the column index, and the rest the flattened pixel values.
        row_ids = flattened_patches[:, :, 0].long()
        column_ids = flattened_patches[:, :, 1].long()
        pixel_values = flattened_patches[:, :, 2:]

        projected = self.patch_projection(pixel_values)
        # Patch content plus its 2-D position, summed in the same order as the T5-style reference.
        summed = projected + self.row_embedder(row_ids) + self.column_embedder(column_ids)

        return self.dropout(summed)
class Pix2StructVisionAttention(nn.Module):
    """
    Multi-head self-attention used by the Pix2Struct vision encoder.

    Attention scores are not scaled by `1 / sqrt(d_kv)` before the softmax (Mesh TensorFlow convention). Padding
    is handled through an additive bias derived from `attention_mask`.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_attention_heads
        self.dropout = config.attention_dropout
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.output = nn.Linear(self.inner_dim, self.hidden_size, bias=False)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        """
        Self-attention block.

        Args:
            hidden_states: input of shape `(batch_size, seq_length, hidden_size)`.
            attention_mask: optional mask where 1 means "attend" and 0 means "padding". Either 2-D
                `(batch_size, key_length)` or already broadcastable to `(batch_size, n_heads, seq_length, key_length)`.
                When `None`, every position is attended to.
            position_bias: optional precomputed bias in which entries equal to 1 are masked out. When given, the
                attention mask is not applied a second time.
            layer_head_mask: optional mask multiplied into the attention probabilities.
            output_attentions: whether to also return the attention probabilities.

        Returns:
            `(attn_output, position_bias)`, plus `attn_weights` when `output_attentions` is True.
        """
        batch_size, seq_length = hidden_states.shape[:2]

        def to_projection_shape(states):
            """Split the last dimension into heads: (batch_size, n_heads, seq_length, dim_per_head)."""
            return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        # (batch_size, n_heads, seq_length, dim_per_head)
        query_states = to_projection_shape(self.query(hidden_states))
        key_states = to_projection_shape(self.key(hidden_states))
        value_states = to_projection_shape(self.value(hidden_states))

        # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            position_bias = torch.zeros(
                (1, self.n_heads, seq_length, seq_length), device=scores.device, dtype=scores.dtype
            )
            if self.gradient_checkpointing and self.training:
                position_bias.requires_grad = True

            # Check for `None` before touching the mask: without a mask every position is attendable. The
            # previous code called `attention_mask.dim()` first (crashing on `None`), and under torchdynamo would
            # have left the bias at zero, which the `1 - bias` step below turns into "mask everything".
            if attention_mask is None:
                attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype)

            if attention_mask.dim() == 2:
                position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
            else:
                # (batch_size, n_heads, seq_length, key_length)
                position_bias = position_bias + attention_mask.to(position_bias.device)
            # Flip the convention: kept positions become 0, padded positions become 1.
            position_bias = 1 - position_bias

        position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min)
        scores += position_bias_masked
        # Guard against -inf/overflow after adding the large negative mask value.
        scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min))

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)
        # (batch_size, seq_length, dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
        attn_output = self.output(attn_output)

        outputs = (attn_output,) + (position_bias,)
        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
# Adapted from transformers.models.t5.modeling_t5.T5DenseGatedActDense (T5Config -> Pix2StructVisionConfig,
# config.d_model -> config.hidden_size)
class Pix2StructVisionMlp(nn.Module):
    """Gated feed-forward block of the vision encoder: `wo(dropout(act(wi_0(x)) * wi_1(x)))`."""

    def __init__(self, config: Pix2StructVisionConfig):
        super().__init__()
        self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        gate = self.act(self.wi_0(hidden_states))
        gated = self.dropout(gate * self.wi_1(hidden_states))

        # `wo` may be kept in float32 (8-bit quantization of google/flan-t5-xxl, see
        # https://github.com/huggingface/transformers/issues/20287), so match its dtype -- unless the weight is
        # int8, which happens if users force `_keep_in_fp32_modules` to be `None`.
        out_weight = self.wo.weight
        cast_needed = (
            isinstance(out_weight, torch.Tensor)
            and gated.dtype != out_weight.dtype
            and out_weight.dtype != torch.int8
        )
        if cast_needed:
            gated = gated.to(out_weight.dtype)

        return self.wo(gated)
class Pix2StructVisionLayer(nn.Module):
    """
    One vision-encoder block: pre-norm self-attention and pre-norm gated MLP, each followed by a residual add.

    Returns `(layer_output, *attention_extras)`, where the extras are everything the attention module returned
    after its first element.
    """

    def __init__(self, config: Pix2StructConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Pix2StructVisionAttention(config)
        self.mlp = Pix2StructVisionMlp(config)
        self.pre_mlp_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pre_attention_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # Self-attention sub-layer: normalize first, then add the input back.
        attention_results = self.attention(
            self.pre_attention_layer_norm(hidden_states),
            attention_mask=attention_mask,
            layer_head_mask=head_mask,
            output_attentions=output_attentions,
        )
        attended = attention_results[0] + hidden_states

        # MLP sub-layer, also pre-normalized, with its own residual connection.
        mlp_output = self.mlp(self.pre_mlp_layer_norm(attended)) + attended

        # Forward the attention module's extra outputs (position bias, optional weights) unchanged.
        return (mlp_output,) + attention_results[1:]
class Pix2StructVisionEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `Pix2StructVisionLayer`s applied in sequence."""

    def __init__(self, config: Pix2StructConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        """
        Run every layer on `hidden_states`.

        Args:
            hidden_states: embedded patches fed to the first layer.
            attention_mask: padding mask forwarded unchanged to each layer.
            head_mask: optional per-layer head mask; `head_mask[i]` is passed to layer `i`.
            output_attentions: collect each layer's attention probabilities.
            output_hidden_states: collect the input to every layer plus the final output.
            return_dict: return a `BaseModelOutput` instead of a tuple of the non-`None` values.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                # With output_attentions the layer returns (hidden_states, position_bias, attn_weights).
                # Index 1 is the mask-derived position bias, so take the attention probabilities from the end.
                all_self_attentions = all_self_attentions + (layer_outputs[-1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class Pix2StructPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Pix2StructConfig

    _supports_cache_class = True
    _supports_static_cache = False

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """
        Initialize the weights of `module`.

        Text-side blocks (gated MLP, attention, embeddings, LM head) get Mesh-TensorFlow-style normal
        initialization scaled by `initializer_factor`. Any other `nn.Linear` / `nn.Conv2d` gets a truncated normal
        with `initializer_range`.
        """
        factor = self.config.initializer_factor  # Used for testing weights initialization
        # Text hyper-parameters live on `config.text_config` for the composite Pix2StructConfig and directly on
        # the config for standalone sub-models.
        is_composite_config = isinstance(self.config, Pix2StructConfig)
        text_config = self.config.text_config if is_composite_config else self.config

        if isinstance(module, Pix2StructLayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, Pix2StructTextDenseGatedActDense):
            hidden_size = text_config.hidden_size
            d_ff = text_config.d_ff

            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, Pix2StructTextAttention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            hidden_size = text_config.hidden_size
            # NOTE(review): for a standalone (non-composite) config this falls back to `hidden_size`, not `d_kv`.
            # Behavior preserved as-is; confirm whether that is intended.
            key_value_proj_dim = text_config.d_kv if is_composite_config else text_config.hidden_size
            n_heads = text_config.num_heads

            module.query.weight.data.normal_(mean=0.0, std=factor * ((hidden_size * key_value_proj_dim) ** -0.5))
            module.key.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5))
            module.value.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5))
            module.output.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
        elif isinstance(module, nn.Embedding):
            hidden_size = text_config.hidden_size

            module.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Pix2StructTextModel):
            hidden_size = text_config.hidden_size

            module.lm_head.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        # The former trailing `Pix2StructLayerNorm` / `nn.Embedding` branches were unreachable: both types are
        # already matched by the earlier branches above, so they have been removed.

    # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->Pix2Struct
    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id. "
                "See Pix2Struct docs for more information."
            )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
PIX2STRUCT_VISION_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Pix2StructConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

PIX2STRUCT_VISION_INPUTS_DOCSTRING = r"""
    Args:
        flattened_patches (`torch.FloatTensor` of shape `(batch_size, sequence_length, 2 + num_channels x patch_height x patch_width)`):
            Flattened and padded pixel values. For each patch, the first two values hold its row and column indices
            and the remaining values hold the flattened pixels. These values can be obtained using
            [`AutoImageProcessor`]. See [`Pix2StructImageProcessor.__call__`] for details. Check the [original
            paper](https://arxiv.org/abs/2210.03347) (figure 5) for more details.

        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:

            - 1 for patches that are **not masked**,
            - 0 for patches that are **masked** (padding).

            If not provided, [`Pix2StructVisionModel`] derives it from `flattened_patches`, treating patches whose
            values sum to zero as padding.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare Pix2StructVision Model transformer outputting raw hidden-states without any specific head on top.",
    PIX2STRUCT_VISION_START_DOCSTRING,
)
class Pix2StructVisionModel(Pix2StructPreTrainedModel):
    config_class = Pix2StructVisionConfig
    main_input_name = "flattened_patches"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Pix2StructVisionLayer"]

    def __init__(self, config: Pix2StructConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = Pix2StructVisionEmbeddings(config)
        self.encoder = Pix2StructVisionEncoder(config)

        self.layernorm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the linear layer that projects raw patch pixels into the hidden space."""
        return self.embeddings.patch_projection

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # NOTE(review): Pix2StructVisionAttention (as defined in this file) has no `prune_heads` method, so this
        # raises AttributeError if reached -- confirm before relying on head pruning for the vision tower.
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(PIX2STRUCT_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        flattened_patches: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Returns:

        Example:

        ```python
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from transformers import AutoProcessor, Pix2StructVisionModel

        >>> image_processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
        >>> model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")

        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 2048, 768]
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if flattened_patches is None:
            raise ValueError("You have to specify flattened_patches")

        if attention_mask is None:
            # check where `flattened_patches` is not 0
            attention_mask = (flattened_patches.sum(dim=-1) != 0).float()

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(flattened_patches)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            head_outputs = (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        # No pooler: the documented/annotated return type is BaseModelOutput to match what is returned here.
        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# Adapted from transformers.models.t5.modeling_t5.T5DenseGatedActDense (T5 -> Pix2StructText, d_model -> hidden_size)
class Pix2StructTextDenseGatedActDense(nn.Module):
    """Gated feed-forward block of the text decoder: `wo(dropout(act(wi_0(x)) * wi_1(x)))`."""

    def __init__(self, config: Pix2StructTextConfig):
        super().__init__()
        self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        activated_gate = self.act(self.wi_0(hidden_states))
        gated = self.dropout(activated_gate * self.wi_1(hidden_states))

        # `wo` may be kept in float32 (8-bit quantization of google/flan-t5-xxl, see
        # https://github.com/huggingface/transformers/issues/20287), so match its dtype -- unless the weight is
        # int8, which happens if users force `_keep_in_fp32_modules` to be `None`.
        projection_weight = self.wo.weight
        should_cast = (
            isinstance(projection_weight, torch.Tensor)
            and gated.dtype != projection_weight.dtype
            and projection_weight.dtype != torch.int8
        )
        if should_cast:
            gated = gated.to(projection_weight.dtype)

        return self.wo(gated)
class Pix2StructTextLayerFF(nn.Module):
    """Pre-norm feed-forward sub-layer of the text decoder with a dropout-wrapped residual connection."""

    def __init__(self, config: Pix2StructTextConfig):
        super().__init__()
        # Attribute name kept as `DenseReluDense` (it is a gated block) because module attribute names form the
        # state-dict keys of saved checkpoints.
        self.DenseReluDense = Pix2StructTextDenseGatedActDense(config)
        self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    # Adapted from transformers.models.t5.modeling_t5.T5LayerFF.forward
    def forward(self, hidden_states):
        update = self.DenseReluDense(self.layer_norm(hidden_states))
        return hidden_states + self.dropout(update)
class Pix2StructTextAttention(nn.Module):
    """
    Multi-head attention used by the Pix2Struct text decoder, either as self-attention (when `key_value_states` is
    None) or as cross-attention over `key_value_states`.

    As in T5, query-key scores are not scaled before the softmax, and an optional learned relative-position bias
    (`has_relative_attention_bias=True`) is added to the scores.

    Args:
        config (`Pix2StructTextConfig`):
            Supplies `hidden_size`, `d_kv`, `num_heads`, `dropout_rate`, `relative_attention_num_buckets` and
            `relative_attention_max_distance`.
        has_relative_attention_bias (`bool`, *optional*, defaults to `False`):
            Whether this layer owns a relative-position bias embedding table.
        layer_idx (`int`, *optional*):
            Index of this layer; used to address its slot in the key/value cache.
    """

    def __init__(
        self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: Optional[int] = None
    ):
        super().__init__()
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.hidden_size = config.hidden_size
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        # NOTE(review): the projections below are hidden_size -> hidden_size while heads are split/merged with
        # `n_heads * key_value_proj_dim`; this assumes the config keeps num_heads * d_kv == hidden_size.
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.key = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.value = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.output = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias
    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        """
        Compute binned relative position bias.

        Query positions are `cache_position` when given, otherwise `0..query_length-1`; key positions are
        `0..key_length-1`. Buckets are computed unidirectionally (`bidirectional=False`).

        Returns:
            `torch.Tensor` of shape `(1, num_heads, num_query_positions, key_length)`.
        """
        if device is None:
            device = self.relative_attention_bias.weight.device
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        else:
            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=False,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward
    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Returns:
            A tuple `(attn_output, past_key_value, position_bias)`, extended with `attn_weights` when
            `output_attentions=True`. When `position_bias` is computed here, `mask` has already been added to it.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.query(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_value.cross_attention_cache
            else:
                curr_past_key_value = past_key_value.self_attention_cache

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.key(current_states)
            value_states = self.value(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                # Full query length including any cached past. `cache_position` is 0-indexed, hence the +1. With
                # neither hint there is no past to account for, so the current sequence is the whole query.
                if query_length is not None:
                    real_seq_length = query_length
                elif cache_position is not None:
                    real_seq_length = cache_position[-1] + 1
                else:
                    real_seq_length = seq_length
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
                )
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            # Boolean keep-mask over heads, built on the bias device; named so it does not shadow `mask`.
            head_keep_mask = torch.ones(position_bias.shape[1], dtype=torch.bool, device=position_bias.device)
            head_keep_mask[list(self.pruned_heads)] = False
            position_bias_masked = position_bias[:, head_keep_mask]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked
        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.output(attn_output)

        outputs = (attn_output, past_key_value, position_bias)
        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerSelfAttention->Pix2StructTextLayerSelfAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size
class Pix2StructTextLayerSelfAttention(nn.Module):
    """Pre-norm residual self-attention sub-layer: `x + dropout(attention(layer_norm(x)))`."""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.attention = Pix2StructTextAttention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
        )
        self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Returns:
            `(hidden_states, *extras)`, where `extras` is everything the attention module returned after its
            output: the cache, the position bias and, if requested, the attention weights.
        """
        attention_result = self.attention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        attended, extras = attention_result[0], attention_result[1:]
        return (hidden_states + self.dropout(attended),) + extras
# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerCrossAttention->Pix2StructTextLayerCrossAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size
class Pix2StructTextLayerCrossAttention(nn.Module):
    """
    Pre-norm residual cross-attention sub-layer: `x + dropout(attention(layer_norm(x), key_value_states))`.
    The underlying attention never has a relative-position bias table.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Returns:
            `(layer_output, *extras)`, where `extras` is everything the attention module returned after its
            output: the cache, the position bias and, if requested, the attention weights.
        """
        attention_result = self.attention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        attended, extras = attention_result[0], attention_result[1:]
        return (hidden_states + self.dropout(attended),) + extras
class Pix2StructTextBlock(nn.Module):
    """
    One Pix2Struct text-decoder layer: self-attention, then cross-attention over `encoder_hidden_states` (only when
    they are given), then the gated feed-forward network. In fp16, each sub-layer output that contains an inf is
    clamped to +/-(finfo.max - 1000).
    """

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.self_attention = Pix2StructTextLayerSelfAttention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
        )
        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config, layer_idx=layer_idx)
        self.mlp = Pix2StructTextLayerFF(config)

    @staticmethod
    def _clamp_fp16_overflow(hidden_states):
        """Clamp fp16 activations to +/-(finfo.max - 1000) if any element is inf; otherwise return them as-is."""
        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            limit = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)
        return hidden_states

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
    ):
        """
        Returns:
            `(hidden_states, past_key_value, *attention_outputs)` when `use_cache`, else
            `(hidden_states, *attention_outputs)`. `attention_outputs` holds the self-attention position bias (plus
            its attention weights when `output_attentions`), followed by the same for cross-attention when
            `encoder_hidden_states` is given. `return_dict` is accepted but not used here.
        """
        hidden_states, past_key_value, *self_extras = self.self_attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        # Keep self-attention outputs and relative position weights
        attention_outputs = tuple(self_extras)
        hidden_states = self._clamp_fp16_overflow(hidden_states)

        if encoder_hidden_states is not None:
            hidden_states, past_key_value, *cross_extras = self.encoder_decoder_attention(
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = self._clamp_fp16_overflow(hidden_states)
            # Keep cross-attention outputs and relative position weights
            attention_outputs += tuple(cross_extras)

        # Apply Feed Forward layer
        hidden_states = self._clamp_fp16_overflow(self.mlp(hidden_states))

        if use_cache:
            return (hidden_states, past_key_value) + attention_outputs
        return (hidden_states,) + attention_outputs
# Class-level docstring prefix shared by the Pix2Struct models (passed to `add_start_docstrings`).
PIX2STRUCT_START_DOCSTRING = r"""

    The Pix2Struct model was proposed in [Pix2Struct: Screenshot Parsing as Pretraining for Visual Language
    Understanding](https://arxiv.org/abs/2210.03347) by Kenton Lee, Mandar Joshi, Iulia Turc, Hexiang Hu, Fangyu Liu,
    Julian Eisenschlos, Urvashi Khandelwal, Peter Shaw, Ming-Wei Chang, Kristina Toutanova. It's an encoder-decoder
    transformer pre-trained in an image-to-text setting.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (Union[`Pix2StructConfig`, `Pix2StructTextConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
# Argument documentation injected into `Pix2StructTextModel.forward` via `add_start_docstrings_to_model_forward`.
PIX2STRUCT_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Pix2StructText is a model with relative position
            embeddings so you should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [Pix2StructText
            Training](./t5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden states at the output of the last layer of the encoder, attended to by the
            cross-attention layers of the decoder. When not provided, cross-attention is skipped.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
            Mask to avoid performing cross-attention on padding positions of `encoder_hidden_states`.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pix2StructText
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention layers. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss, returned as `loss`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
            cache in the correct position and to infer the complete sequence length.
"""
# Shared argument documentation for Pix2Struct models that take flattened image patches as input.
PIX2STRUCT_INPUTS_DOCSTRING = r"""
    Args:
        flattened_patches (`torch.FloatTensor` of shape `(batch_size, seq_length, hidden_size)`):
            Flattened pixel patches. The `hidden_size` is obtained by the following formula: `hidden_size` =
            `num_channels` * `patch_size` * `patch_size`

            The process of flattening the pixel patches is done by `Pix2StructProcessor`.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pix2StructText
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*)
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention layers. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss for the decoder.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
- @add_start_docstrings(
- "The standalone text decoder of Pix2Struct",
- PIX2STRUCT_START_DOCSTRING,
- )
- class Pix2StructTextModel(Pix2StructPreTrainedModel):
- config_class = Pix2StructTextConfig
- _no_split_modules = ["Pix2StructTextBlock"]
- _tied_weights_keys = ["lm_head.weight"]
- supports_gradient_checkpointing = True
def __init__(self, config):
    """
    Build the standalone Pix2Struct text decoder: token embeddings, `config.num_layers` decoder blocks, a final
    layer norm, dropout and a bias-free LM head.

    Args:
        config (`Pix2StructTextConfig`):
            Supplies `vocab_size`, `hidden_size`, `num_layers`, `layer_norm_epsilon` and `dropout_rate`, plus whatever
            the decoder blocks read.
    """
    super().__init__(config)
    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

    # Only the first block (i == 0) owns a relative-position bias table.
    self.layer = nn.ModuleList(
        [
            Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
            for i in range(config.num_layers)
        ]
    )
    self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
    self.dropout = nn.Dropout(config.dropout_rate)

    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    # Set the default *before* `post_init()`: `post_init()` may switch gradient checkpointing on (e.g. for a config
    # carrying the legacy `gradient_checkpointing` flag), and assigning False afterwards would silently undo that.
    self.gradient_checkpointing = False

    # Initialize weights and apply final processing
    self.post_init()
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._reorder_cache
def _reorder_cache(self, past_key_values, beam_idx):
    """
    Re-index a tuple-format cache so that it follows `beam_idx`.

    Every tensor in every per-layer tuple is re-indexed along dim 0 with `beam_idx` (moved to that tensor's device).
    Returns `None`, after logging a hint, when there is no cache to reorder.

    Raises:
        ValueError: if a reordered layer tuple does not match the original's first-tensor shape or length.
    """
    # Without a decoder past there is nothing to reorder (caching was disabled).
    if past_key_values is None:
        logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
        return past_key_values

    reordered_decoder_past = ()
    for layer_past_states in past_key_values:
        # The batch/beam dimension is dim 0 of each cached tensor.
        reordered_layer_past_states = tuple(
            state.index_select(0, beam_idx.to(state.device)) for state in layer_past_states
        )

        if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
            raise ValueError(
                f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
            )
        if len(reordered_layer_past_states) != len(layer_past_states):
            raise ValueError(
                f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
            )

        reordered_decoder_past += (reordered_layer_past_states,)
    return reordered_decoder_past
def get_input_embeddings(self):
    """Return the token embedding module (`self.embed_tokens`)."""
    return self.embed_tokens

def set_input_embeddings(self, new_embeddings):
    """Install `new_embeddings` as the token embedding module."""
    self.embed_tokens = new_embeddings

def get_output_embeddings(self):
    """Return the LM head projection (`self.lm_head`)."""
    return self.lm_head

def set_output_embeddings(self, new_embeddings):
    """Install `new_embeddings` as the LM head projection."""
    self.lm_head = new_embeddings
@add_start_docstrings_to_model_forward(PIX2STRUCT_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    cross_attn_head_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    labels: Optional[torch.LongTensor] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]:
    r"""
    Returns:

    Example:

    ```python
    >>> from transformers import AutoProcessor, Pix2StructTextModel

    >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
    >>> model = Pix2StructTextModel.from_pretrained("google/pix2struct-textcaps-base")

    >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    ```
    """
    # Runs the text decoder stack, applies the final layer norm + dropout, and projects to vocabulary
    # logits with `self.lm_head`. When `labels` is given, the loss is token-level cross-entropy between
    # `logits` and `labels` exactly as passed (no shifting happens here; label -100 is ignored).
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if inputs_embeds is None:
        assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
        inputs_embeds = self.embed_tokens(input_ids)

    batch_size, seq_length = input_shape

    # Decide on caching BEFORE any cache object is built: with gradient checkpointing in training mode
    # the layers never receive a cache (see the checkpointed call below), so building one here would only
    # emit spurious warnings and return an empty cache.
    if self.gradient_checkpointing and self.training and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
        )
        use_cache = False

    # initialize past_key_values
    return_legacy_cache = False
    return_self_attention_cache = False
    if use_cache or past_key_values is not None:
        if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
            # A self-attention-only cache: pair it with an empty cross-attention cache and hand back
            # only the self-attention part at the end.
            return_self_attention_cache = True
            past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
        elif not isinstance(past_key_values, EncoderDecoderCache):
            # NOTE(review): `None` is not an EncoderDecoderCache, so `use_cache=True` with no cache also
            # lands here (deprecation warning + legacy-tuple return); the `is None` branch below is
            # therefore unreachable. Kept as-is to preserve the returned cache format.
            return_legacy_cache = True
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
        elif past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

    if cache_position is not None:
        past_key_values_length = cache_position[0]
    elif past_key_values is not None:
        past_key_values_length = past_key_values.get_seq_length()
    else:
        past_key_values_length = 0

    if cache_position is None:
        cache_position = torch.arange(
            past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
        )

    if attention_mask is None:
        # No mask given: attend to every cached position plus every new token.
        mask_seq_length = (
            past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length
        )
        attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

    if self.config.is_decoder:
        causal_mask = self._update_causal_mask(
            attention_mask,
            inputs_embeds,
            cache_position,
            past_key_values.self_attention_cache if past_key_values is not None else None,
            output_attentions,
        )
    else:
        # Non-causal use: turn the 2D keep-mask into an additive [batch, 1, 1, key_len] bias.
        causal_mask = attention_mask[:, None, None, :]
        causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
        causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if encoder_hidden_states is not None:
        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
    else:
        encoder_extended_attention_mask = None

    # Prepare head mask if needed
    head_mask = self.get_head_mask(head_mask, self.config.num_layers)
    cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
    all_hidden_states = () if output_hidden_states else None
    all_attentions = () if output_attentions else None
    all_cross_attentions = () if output_attentions else None
    position_bias = None
    encoder_decoder_position_bias = None
    # Pre-bound so the `next_cache` computation below is well-defined even with zero layers.
    next_decoder_cache = None

    hidden_states = self.dropout(inputs_embeds)

    for i, layer_module in enumerate(self.layer):
        layer_head_mask = head_mask[i]
        cross_attn_layer_head_mask = cross_attn_head_mask[i]
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if self.gradient_checkpointing and self.training:
            # Positional call: the checkpoint wrapper forwards positional args to `layer_module.forward`.
            layer_outputs = self._gradient_checkpointing_func(
                layer_module.forward,
                hidden_states,
                causal_mask,
                position_bias,
                encoder_hidden_states,
                encoder_extended_attention_mask,
                encoder_decoder_position_bias,
                layer_head_mask,
                cross_attn_layer_head_mask,
                None,  # past_key_value is always None with gradient checkpointing
                use_cache,
                output_attentions,
                cache_position,
            )
        else:
            layer_outputs = layer_module(
                hidden_states,
                attention_mask=causal_mask,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_extended_attention_mask,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )

        # layer_outputs is a tuple with:
        # hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
        # (cross-attention position bias), (cross-attention weights)
        # Without a cache the key-value slot is absent, so a placeholder keeps the indices below aligned.
        if use_cache is False:
            layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

        hidden_states, next_decoder_cache = layer_outputs[:2]

        # The position biases computed by the first layer are reused by every following layer.
        position_bias = layer_outputs[2]
        if encoder_hidden_states is not None:
            encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]

        if output_attentions:
            all_attentions = all_attentions + (layer_outputs[3],)
            if encoder_hidden_states is not None:
                all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

    hidden_states = self.final_layer_norm(hidden_states)
    hidden_states = self.dropout(hidden_states)

    logits = self.lm_head(hidden_states)

    # Add last layer
    if output_hidden_states:
        all_hidden_states = all_hidden_states + (hidden_states,)

    loss = None
    if labels is not None:
        # move labels to correct device to enable model parallelism
        labels = labels.to(logits.device)
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
        loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1))

    # Hand the cache back in the same flavour the caller supplied it in.
    next_cache = next_decoder_cache if use_cache else None
    if return_self_attention_cache:
        next_cache = past_key_values.self_attention_cache
    if return_legacy_cache:
        next_cache = past_key_values.to_legacy_cache()

    if not return_dict:
        return tuple(
            v
            for v in [
                loss,
                logits,
                next_cache,
                all_hidden_states,
                all_attentions,
                all_cross_attentions,
            ]
            if v is not None
        )
    return CausalLMOutputWithCrossAttentions(
        loss=loss,
        logits=logits,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_attentions,
        cross_attentions=all_cross_attentions,
    )
# Adapted from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """Build the self-attention mask expected by the configured attention implementation.

    - ``flash_attention_2``: returns ``attention_mask`` unchanged if it contains a ``0.0`` entry, else ``None``.
    - ``sdpa`` (no static cache, no attention outputs): returns ``None`` when
      ``AttentionMaskConverter._ignore_causal_mask_sdpa`` allows relying on SDPA's ``is_causal``.
    - otherwise: a 4D additive mask from ``_prepare_4d_causal_attention_mask_with_cache_position``; for
      ``sdpa`` with a CUDA mask, fully masked rows are re-opened via ``_unmask_unattended``.
    """
    implementation = self.config._attn_implementation

    if implementation == "flash_attention_2":
        mask_has_padding = attention_mask is not None and 0.0 in attention_mask
        return attention_mask if mask_has_padding else None

    seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
    static_cache = isinstance(past_key_values, StaticCache)

    # SDPA can dispatch to its fused causal kernel when no explicit mask is needed. That shortcut cannot
    # be used with a static cache or when attention weights are requested (eager fallback).
    is_sdpa = implementation == "sdpa"
    if is_sdpa and not (static_cache or output_attentions):
        skip_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=seen_tokens,
            is_training=self.training,
        )
        if skip_mask:
            return None

    mask_dtype = input_tensor.dtype
    mask_device = input_tensor.device
    query_len = input_tensor.shape[1]

    if static_cache:
        key_len = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        key_len = attention_mask.shape[-1]
    else:
        key_len = seen_tokens + query_len + 1

    # A 2D mask is expanded into a 4D causal mask here; a 4D mask is passed through.
    full_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=query_len,
        target_length=key_len,
        dtype=mask_dtype,
        device=mask_device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    # The memory-efficient SDPA path needs fully masked rows (e.g. left padding) to attend to something.
    # Details: https://github.com/pytorch/pytorch/issues/110213
    reopen_rows = (
        is_sdpa
        and attention_mask is not None
        and attention_mask.device.type == "cuda"
        and not output_attentions
    )
    if reopen_rows:
        full_mask = AttentionMaskConverter._unmask_unattended(full_mask, torch.finfo(mask_dtype).min)
    return full_mask
@staticmethod
# Adapted from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`, *optional*):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case only the causal
            structure is applied.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.

    Returns:
        `torch.Tensor`: an additive mask where `0` means "attend" and `torch.finfo(dtype).min` means "masked".
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        min_dtype = torch.finfo(dtype).min
        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # Keep a key position masked only if it lies after the query's absolute cache position.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            # The 2D mask may live on another device (e.g. model parallelism); align it before combining.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                causal_mask.device
            )
            # A sum of exactly 0 means "causally visible" (0) and "padding" (mask value 0): mask it out.
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask
@add_start_docstrings(
    "A conditional generation model with a language modeling head. Can be used for sequence generation tasks.",
    PIX2STRUCT_START_DOCSTRING,
)
class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMixin):
    # Image-to-text model: a Pix2StructVisionModel encoder over flattened patches feeding a
    # Pix2StructTextModel decoder that owns the LM head.
    config_class = Pix2StructConfig
    main_input_name = "flattened_patches"
    _tied_weights_keys = ["decoder.lm_head.weight"]

    def __init__(self, config: Pix2StructConfig):
        super().__init__(config)

        self.encoder = Pix2StructVisionModel(config.vision_config)
        self.decoder = Pix2StructTextModel(config.text_config)

        self.is_vqa = config.is_vqa

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.decoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.decoder.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.decoder.set_output_embeddings(new_embeddings)

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
        """Resize the decoder's token embeddings and record the resulting vocabulary size.

        Args:
            new_num_tokens: Requested number of tokens; forwarded unchanged to
                `self.decoder.resize_token_embeddings`.

        Returns:
            The embedding module returned by the decoder.
        """
        model_embeds = self.decoder.resize_token_embeddings(new_num_tokens)

        # update vocab size from the embedding actually produced: `new_num_tokens` may be None, which
        # previously overwrote `text_config.vocab_size` with None.
        self.config.text_config.vocab_size = model_embeds.weight.shape[0]

        return model_embeds

    def get_decoder(self):
        return self.decoder

    def get_encoder(self):
        return self.encoder

    @add_start_docstrings_to_model_forward(PIX2STRUCT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        flattened_patches: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        Returns:

        Example:

        Inference:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration

        >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
        >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> # autoregressive generation
        >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> print(generated_text)
        A stop sign is on a street corner.

        >>> # conditional generation
        >>> text = "A picture of"
        >>> inputs = processor(text=text, images=image, return_tensors="pt", add_special_tokens=False)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> print(generated_text)
        A picture of a stop sign with a red stop sign
        ```

        Training:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration

        >>> processor = AutoProcessor.from_pretrained("google/pix2struct-base")
        >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")

        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "A stop sign is on the street corner."

        >>> inputs = processor(images=image, return_tensors="pt")
        >>> labels = processor(text=text, return_tensors="pt").input_ids

        >>> # forward pass
        >>> outputs = model(**inputs, labels=labels)
        >>> loss = outputs.loss
        >>> print(f"{loss.item():.5f}")
        5.94282
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                flattened_patches=flattened_patches,
                attention_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)
            if decoder_attention_mask is None:
                decoder_attention_mask = decoder_input_ids.ne(self.config.pad_token_id).float()
            else:
                # Copy so the in-place edit below does not modify the caller's tensor.
                decoder_attention_mask = decoder_attention_mask.clone()
            # Always attend to the first token
            decoder_attention_mask[:, 0] = 1

        # Decode, cross-attending to the encoder output under the encoder's patch mask.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            labels=labels,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            if isinstance(encoder_outputs, BaseModelOutput):
                # A caller-supplied BaseModelOutput is not a tuple and cannot be concatenated directly.
                encoder_outputs = encoder_outputs.to_tuple()
            return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=decoder_outputs.loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
|