# coding=utf-8
# Copyright 2023 The HuggingFace Inc. & Google team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pix2Struct modeling file"""

import math
from typing import Dict, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    CausalLMOutputWithCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_torch_fx_proxy,
    is_torchdynamo_compiling,
    logging,
    replace_return_docstrings,
)
from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "Pix2StructConfig"


# Adapted from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pix2Struct
class Pix2StructLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
        # half-precision inputs is done in fp32
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states


try:
    from apex.normalization import FusedRMSNorm

    Pix2StructLayerNorm = FusedRMSNorm  # noqa

    logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pix2StructLayerNorm")
except ImportError:
    # using the normal Pix2StructLayerNorm
    pass
except Exception:
    logger.warning("Discovered apex but it failed to load, falling back to Pix2StructLayerNorm")
    pass

ALL_LAYERNORM_LAYERS.append(Pix2StructLayerNorm)
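

# Illustrative sketch (not part of the library): the comment in `Pix2StructLayerNorm.forward` above
# describes an RMS-style norm, i.e. scaling by the root mean square of the features with no mean
# subtraction and no bias. The hypothetical helper below reproduces that computation with plain
# tensor ops, assuming a float input of shape (batch, seq_len, hidden_size) and a learned weight.
def _example_rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # mean of squares over the last (hidden) dimension, no mean subtraction
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    # scale by the reciprocal RMS, then by the learned weight
    return weight * hidden_states * torch.rsqrt(variance + eps)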


class Pix2StructVisionEmbeddings(nn.Module):
    r"""
    Construct the embeddings from patches. In `Pix2Struct` the input is different from classic Vision-transformer models.
    Here the input is a sequence of `seq_len` flattened patches that also combines padding patches (tokens). Each patch
    is represented by a vector of `hidden_size` values.
    """

    def __init__(self, config: Pix2StructConfig) -> None:
        super().__init__()
        self.patch_projection = nn.Linear(config.patch_embed_hidden_size, config.hidden_size)

        self.row_embedder = nn.Embedding(config.seq_len, config.hidden_size)
        self.column_embedder = nn.Embedding(config.seq_len, config.hidden_size)

        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, flattened_patches: torch.Tensor) -> torch.Tensor:
        # the row and column indices are stored in the first and second position of the flattened_patches
        # flattened_patches: `batch_size`, `seq_len`, `hidden_size` + 2
        row_indices = flattened_patches[:, :, 0].long()
        col_indices = flattened_patches[:, :, 1].long()

        flattened_patches = flattened_patches[:, :, 2:]

        embeddings = self.patch_projection(flattened_patches)
        row_embeddings = self.row_embedder(row_indices)
        col_embeddings = self.column_embedder(col_indices)

        # sum all embeddings together
        embeddings = embeddings + row_embeddings + col_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings
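

# Illustrative sketch (not part of the library): the embeddings module above expects each flattened
# patch to be a vector whose first two entries are its (row, column) position and whose remaining
# entries are the patch pixels. The hypothetical helper below builds such a tensor for a toy grid,
# assuming `rows * cols` patches of `patch_dim` values each; it only demonstrates the layout.
def _example_flattened_patches(rows: int = 2, cols: int = 3, patch_dim: int = 768) -> torch.Tensor:
    patches = torch.rand(1, rows * cols, patch_dim)  # fake pixel content, batch size 1
    row_ids = torch.arange(rows).repeat_interleave(cols).view(1, -1, 1).float()
    col_ids = torch.arange(cols).repeat(rows).view(1, -1, 1).float()
    # final layout per patch: [row_index, col_index, pixel values...]
    return torch.cat([row_ids, col_ids, patches], dim=-1)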


class Pix2StructVisionAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_attention_heads
        self.dropout = config.attention_dropout
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.output = nn.Linear(self.inner_dim, self.hidden_size, bias=False)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        output_attentions=False,
    ):
        """
        Self-attention block
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        def to_projection_shape(states):
            """projection"""
            return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        # get query states
        # (batch_size, n_heads, seq_length, dim_per_head)
        query_states = to_projection_shape(self.query(hidden_states))

        # get key/value states
        key_states = to_projection_shape(self.key(hidden_states))
        value_states = to_projection_shape(self.value(hidden_states))

        # compute scores
        # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            position_bias = torch.zeros(
                (1, self.n_heads, seq_length, seq_length), device=scores.device, dtype=scores.dtype
            )
            if self.gradient_checkpointing and self.training:
                position_bias.requires_grad = True

        if attention_mask is None and not is_torchdynamo_compiling():
            attention_mask = torch.ones(
                (batch_size, seq_length), device=position_bias.device, dtype=position_bias.dtype
            )

        if attention_mask.dim() == 2:
            position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
        elif attention_mask is not None:
            # (batch_size, n_heads, seq_length, key_length)
            position_bias = position_bias + attention_mask.to(position_bias.device)

        position_bias = 1 - position_bias
        position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min)
        scores += position_bias_masked
        scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min))

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)

        # (batch_size, seq_length, dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        attn_output = self.output(attn_output)

        outputs = (attn_output,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
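

# Illustrative sketch (not part of the library): in the vision attention above, a padding mask with
# 1 = keep and 0 = pad is folded into the additive `position_bias`; after `1 - position_bias`,
# padded positions equal 1 and are filled with the most negative representable value so that they
# vanish after the softmax. The hypothetical helper below shows that conversion in isolation.
def _example_additive_padding_bias(keep_mask: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # keep_mask: (batch_size, key_length) with 1.0 for real patches and 0.0 for padding
    bias = 1 - keep_mask[:, None, None, :].to(dtype)  # (batch_size, 1, 1, key_length)
    return bias.masked_fill(bias == 1, torch.finfo(dtype).min)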


# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5DenseGatedActDense->Pix2StructVisionMlp,T5Config->Pix2StructVisionConfig,config.d_model->config.hidden_size,dropout_rate->dropout_rate
class Pix2StructVisionMlp(nn.Module):
    def __init__(self, config: Pix2StructVisionConfig):
        super().__init__()
        self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)

        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
        # See https://github.com/huggingface/transformers/issues/20287
        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)

        hidden_states = self.wo(hidden_states)
        return hidden_states
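

# Illustrative sketch (not part of the library): the MLP above is a gated feed-forward block, i.e.
# an activated projection multiplied elementwise by a linear projection before being mapped back to
# the model width. The hypothetical function below spells out that computation with explicit weight
# tensors instead of `nn.Linear` modules, assuming a GELU activation.
def _example_gated_ff(x: torch.Tensor, w0: torch.Tensor, w1: torch.Tensor, wo: torch.Tensor) -> torch.Tensor:
    # x: (..., hidden_size); w0, w1: (d_ff, hidden_size); wo: (hidden_size, d_ff)
    gate = nn.functional.gelu(x @ w0.T)  # activated branch
    linear = x @ w1.T  # ungated branch
    return (gate * linear) @ wo.T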


class Pix2StructVisionLayer(nn.Module):
    def __init__(self, config: Pix2StructConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Pix2StructVisionAttention(config)
        self.mlp = Pix2StructVisionMlp(config)
        self.pre_mlp_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pre_attention_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        residual = hidden_states

        # in Pix2StructVision, layernorm is applied before self-attention
        hidden_states = self.pre_attention_layer_norm(hidden_states)

        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + residual

        # in Pix2StructVision, layernorm is also applied after self-attention
        layer_output = self.pre_mlp_layer_norm(hidden_states)
        layer_output = self.mlp(layer_output) + hidden_states  # second residual connection

        outputs = (layer_output,) + outputs

        return outputs


class Pix2StructVisionEncoder(nn.Module):
    def __init__(self, config: Pix2StructConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class Pix2StructPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Pix2StructConfig

    _supports_cache_class = True
    _supports_static_cache = False

    @property
    def dummy_inputs(self):
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, Pix2StructLayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, Pix2StructTextDenseGatedActDense):
            hidden_size = (
                self.config.text_config.hidden_size
                if isinstance(self.config, Pix2StructConfig)
                else self.config.hidden_size
            )
            d_ff = self.config.text_config.d_ff if isinstance(self.config, Pix2StructConfig) else self.config.d_ff

            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, Pix2StructTextAttention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            hidden_size = (
                self.config.text_config.hidden_size
                if isinstance(self.config, Pix2StructConfig)
                else self.config.hidden_size
            )
            key_value_proj_dim = (
                self.config.text_config.d_kv if isinstance(self.config, Pix2StructConfig) else self.config.hidden_size
            )
            n_heads = (
                self.config.text_config.num_heads
                if isinstance(self.config, Pix2StructConfig)
                else self.config.num_heads
            )

            module.query.weight.data.normal_(mean=0.0, std=factor * ((hidden_size * key_value_proj_dim) ** -0.5))
            module.key.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5))
            module.value.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5))
            module.output.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
        elif isinstance(module, nn.Embedding):
            hidden_size = (
                self.config.text_config.hidden_size
                if isinstance(self.config, Pix2StructConfig)
                else self.config.hidden_size
            )

            module.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, Pix2StructTextModel):
            hidden_size = (
                self.config.text_config.hidden_size
                if isinstance(self.config, Pix2StructConfig)
                else self.config.hidden_size
            )

            module.lm_head.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
        elif isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, Pix2StructLayerNorm):
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->Pix2Struct
    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id. "
                "See Pix2Struct docs for more information."
            )

        # shift inputs to the right
        if is_torch_fx_proxy(input_ids):
            # Item assignment is not supported natively for proxies.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
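

# Illustrative sketch (not part of the library): `_shift_right` above prepends the decoder start
# token, drops the last label, and maps the -100 ignore index back to the pad token. The
# hypothetical helper below reproduces that transformation for a plain label tensor; the default
# id of 0 for both the start and pad token is an assumption made only for this illustration.
def _example_shift_right(labels: torch.Tensor, decoder_start_token_id: int = 0, pad_token_id: int = 0) -> torch.Tensor:
    shifted = labels.new_zeros(labels.shape)
    shifted[..., 1:] = labels[..., :-1].clone()  # move every label one position to the right
    shifted[..., 0] = decoder_start_token_id  # first decoder input is the start token
    return shifted.masked_fill(shifted == -100, pad_token_id)  # -100 (ignored in the loss) -> pad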


PIX2STRUCT_VISION_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Pix2StructConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

PIX2STRUCT_VISION_INPUTS_DOCSTRING = r"""
    Args:
        flattened_patches (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels x patch_height x patch_width)`):
            Flattened and padded pixel values. These values can be obtained using [`AutoImageProcessor`]. See
            [`Pix2StructVisionImageProcessor.__call__`] for details. Check the [original
            paper](https://arxiv.org/abs/2210.03347) (figure 5) for more details.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:

            - 1 for patches that are **not masked**,
            - 0 for patches that are **masked**.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare Pix2StructVision Model transformer outputting raw hidden-states without any specific head on top.",
    PIX2STRUCT_VISION_START_DOCSTRING,
)
class Pix2StructVisionModel(Pix2StructPreTrainedModel):
    config_class = Pix2StructVisionConfig
    main_input_name = "flattened_patches"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Pix2StructVisionLayer"]

    def __init__(self, config: Pix2StructConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = Pix2StructVisionEmbeddings(config)
        self.encoder = Pix2StructVisionEncoder(config)

        self.layernorm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_projection

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(PIX2STRUCT_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        flattened_patches: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Example:

        ```python
        >>> import torch
        >>> import requests
        >>> from PIL import Image
        >>> from transformers import AutoProcessor, Pix2StructVisionModel

        >>> image_processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
        >>> model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")

        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 2048, 768]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if flattened_patches is None:
            raise ValueError("You have to specify flattened_patches")

        if attention_mask is None:
            # check where `flattened_patches` is not 0
            attention_mask = (flattened_patches.sum(dim=-1) != 0).float()

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(flattened_patches)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            head_outputs = (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
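

# Illustrative sketch (not part of the library): when no `attention_mask` is given, the vision model
# above derives one by treating any all-zero patch row as padding. The hypothetical helper below
# shows that rule in isolation for a `(batch_size, seq_len, patch_dim)` tensor of flattened patches.
def _example_default_patch_mask(flattened_patches: torch.Tensor) -> torch.Tensor:
    # 1.0 where the patch row has any non-zero value (real content), 0.0 for zero padding rows
    return (flattened_patches.sum(dim=-1) != 0).float()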


# Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pix2StructText,d_model->hidden_size
class Pix2StructTextDenseGatedActDense(nn.Module):
    def __init__(self, config: Pix2StructTextConfig):
        super().__init__()
        self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_gelu = self.act(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states = self.dropout(hidden_states)

        # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
        # See https://github.com/huggingface/transformers/issues/20287
        # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None`
        if (
            isinstance(self.wo.weight, torch.Tensor)
            and hidden_states.dtype != self.wo.weight.dtype
            and self.wo.weight.dtype != torch.int8
        ):
            hidden_states = hidden_states.to(self.wo.weight.dtype)

        hidden_states = self.wo(hidden_states)
        return hidden_states


class Pix2StructTextLayerFF(nn.Module):
    def __init__(self, config: Pix2StructTextConfig):
        super().__init__()
        self.DenseReluDense = Pix2StructTextDenseGatedActDense(config)
        self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    # Copied from transformers.models.t5.modeling_t5.T5LayerFF.forward
    def forward(self, hidden_states):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + self.dropout(forwarded_states)
        return hidden_states


class Pix2StructTextAttention(nn.Module):
    def __init__(
        self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: Optional[int] = None
    ):
        super().__init__()
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.hidden_size = config.hidden_size
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.key = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.value = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.output = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    @staticmethod
    # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias
    def compute_bias(self, query_length, key_length, device=None, cache_position=None):
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        else:
            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=False,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward
    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.query(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_value.cross_attention_cache
            else:
                curr_past_key_value = past_key_value.self_attention_cache

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.key_cache[self.layer_idx]
            value_states = curr_past_key_value.value_cache[self.layer_idx]
        else:
            key_states = self.key(current_states)
            value_states = self.value(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
                )
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.output(attn_output)

        outputs = (attn_output, past_key_value, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
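

# Illustrative sketch (not part of the library): `_relative_position_bucket` above maps the signed
# distance `memory_position - query_position` to one of `num_buckets` ids, with exact buckets for
# short distances and logarithmically sized buckets out to `max_distance`. The hypothetical helper
# below returns the buckets for a single causal decoder query attending to a few past positions.
def _example_relative_position_buckets() -> torch.Tensor:
    # distances from a query at position 5 to memory positions 0..5 (causal, so all <= 0)
    relative_position = torch.arange(6).view(1, -1) - 5
    return Pix2StructTextAttention._relative_position_bucket(
        relative_position, bidirectional=False, num_buckets=32, max_distance=128
    )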


# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerSelfAttention->Pix2StructTextLayerSelfAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size
class Pix2StructTextLayerSelfAttention(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.attention = Pix2StructTextAttention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
        )
        self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.attention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerCrossAttention->Pix2StructTextLayerCrossAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size
class Pix2StructTextLayerCrossAttention(nn.Module):
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
        cache_position=None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.attention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class Pix2StructTextBlock(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()

        self.self_attention = Pix2StructTextLayerSelfAttention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            layer_idx=layer_idx,
        )

        self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
            config,
            layer_idx=layer_idx,
        )

        self.mlp = Pix2StructTextLayerFF(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_value=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
    ):
        self_attention_outputs = self.self_attention(
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states, past_key_value = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = encoder_hidden_states is not None
        if do_cross_attention:
            cross_attention_outputs = self.encoder_decoder_attention(
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value,
                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states, past_key_value = cross_attention_outputs[:2]

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
                clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[2:]

        # Apply Feed Forward layer
        hidden_states = self.mlp(hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if use_cache:
            outputs = outputs + (past_key_value,) + attention_outputs
        else:
            outputs = outputs + attention_outputs

        return outputs
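

# Illustrative sketch (not part of the library): the repeated clamp in the block above replaces any
# +/-inf produced under float16 with a value just inside the representable range so that half
# precision training does not propagate infinities. The hypothetical helper below isolates that step.
def _example_clamp_fp16_inf(hidden_states: torch.Tensor) -> torch.Tensor:
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000  # 65504 - 1000 for float16
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states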


PIX2STRUCT_START_DOCSTRING = r"""
    The Pix2Struct model was proposed in [Pix2Struct: Screenshot Parsing as Pretraining for Visual Language
    Understanding](https://arxiv.org/abs/2210.03347) by Kenton Lee, Mandar Joshi, Iulia Turc, Hexiang Hu, Fangyu Liu,
    Julian Eisenschlos, Urvashi Khandelwal, Peter Shaw, Ming-Wei Chang, Kristina Toutanova. It's an encoder-decoder
    transformer pre-trained in an image-to-text setting.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config (Union[`Pix2StructConfig`, `Pix2StructTextConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
  901. PIX2STRUCT_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Pix2StructText is a model with relative position
            embeddings so you should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [Pix2StructText
            Training](./t5#training).
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pix2StructText
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*).
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention layers. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
            cache in the correct position and to infer the complete sequence length.
"""

PIX2STRUCT_INPUTS_DOCSTRING = r"""
    Args:
        flattened_patches (`torch.FloatTensor` of shape `(batch_size, seq_length, hidden_size)`):
            Flattened pixel patches. The `hidden_size` is obtained by the following formula: `hidden_size` =
            `num_channels` * `patch_size` * `patch_size`.

            The process of flattening the pixel patches is done by `Pix2StructProcessor`.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
            `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pix2StructText
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
            Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*).
            `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at
            the output of the last layer of the encoder. Used in the cross-attention of the decoder.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention layers. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded
            representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be
            input (see `past_key_values`). This is useful if you want more control over how to convert
            `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix.

            If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
            of `inputs_embeds`.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss for the decoder.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
            cache in the correct position and to infer the complete sequence length.
"""


@add_start_docstrings(
    "The standalone text decoder of Pix2Struct",
    PIX2STRUCT_START_DOCSTRING,
)
class Pix2StructTextModel(Pix2StructPreTrainedModel):
    config_class = Pix2StructTextConfig
    _no_split_modules = ["Pix2StructTextBlock"]
    _tied_weights_keys = ["lm_head.weight"]
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layer = nn.ModuleList(
            [
                Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
                for i in range(config.num_layers)
            ]
        )
        self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
        self.gradient_checkpointing = False

    # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._reorder_cache
    def _reorder_cache(self, past_key_values, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past_key_values

        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                )

            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
                raise ValueError(
                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
                )
            if len(reordered_layer_past_states) != len(layer_past_states):
                raise ValueError(
                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
                )

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past
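
    # Illustrative sketch (hypothetical values, not executed here): `_reorder_cache` re-aligns every
    # cached key/value tensor with the beams selected during beam search by indexing the batch dimension,
    # which is equivalent to
    #
    #     beam_idx = torch.tensor([2, 0, 1])  # beams chosen at this step (hypothetical)
    #     reordered = tuple(
    #         tuple(state.index_select(0, beam_idx) for state in layer_past)
    #         for layer_past in past_key_values
    #     )
    #
    # with an additional device move and per-layer shape check, as done in the loop above.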

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(PIX2STRUCT_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]:
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import AutoProcessor, Pix2StructTextModel

        >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
        >>> model = Pix2StructTextModel.from_pretrained("google/pix2struct-textcaps-base")

        >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits = outputs.logits
        ```
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if inputs_embeds is None:
            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        # initialize past_key_values
        return_legacy_cache = False
        return_self_attention_cache = False
        if use_cache or past_key_values is not None:
            if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
                return_self_attention_cache = True
                past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
            elif not isinstance(past_key_values, EncoderDecoderCache):
                return_legacy_cache = True
                logger.warning_once(
                    "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
                    "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                    "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
                )
                past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
            elif past_key_values is None:
                past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

        past_key_values_length = 0
        if cache_position is not None:
            past_key_values_length = cache_position[0]
        elif past_key_values is not None:
            past_key_values_length = past_key_values.get_seq_length()

        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
            )

        if attention_mask is None:
            # required mask seq length can be calculated via length of past
            mask_seq_length = (
                past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length
            )
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        if self.config.is_decoder:
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values.self_attention_cache if past_key_values is not None else None,
                output_attentions,
            )
        else:
            causal_mask = attention_mask[:, None, None, :]
            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min

        # If a 2D or 3D attention mask is provided for the cross-attention,
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, layer_module in enumerate(self.layer):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.forward,
                    hidden_states,
                    causal_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    layer_head_mask,
                    cross_attn_layer_head_mask,
                    None,  # past_key_value is always None with gradient checkpointing
                    use_cache,
                    output_attentions,
                    cache_position,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    layer_head_mask=layer_head_mask,
                    cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                    past_key_value=past_key_values,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    cache_position=cache_position,
                )

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            hidden_states, next_decoder_cache = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them.
            # layer_outputs = hidden-states, key-value-states, (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[2]
            if encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[3],)
                if encoder_hidden_states is not None:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)
        logits = self.lm_head(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")
            loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1))

        next_cache = next_decoder_cache if use_cache else None
        if return_self_attention_cache:
            next_cache = past_key_values.self_attention_cache
        if return_legacy_cache:
            next_cache = past_key_values.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
                for v in [
                    loss,
                    logits,
                    next_cache,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
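
    # Illustrative sketch (hypothetical tensors, not executed here): incremental decoding with the cache
    # returned by `forward`. The first call processes the prompt; follow-up calls only feed the newly
    # generated token together with `past_key_values`:
    #
    #     out = decoder(input_ids=prompt_ids, use_cache=True)
    #     next_token = out.logits[:, -1:].argmax(-1)
    #     out = decoder(input_ids=next_token, past_key_values=out.past_key_values, use_cache=True)
    #
    # `generate()` on the composite model drives this loop automatically.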

    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
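
    # Note on the convention used above: the returned 4D mask is additive. Positions a query may attend
    # to hold 0.0 and blocked positions hold the dtype minimum (`torch.finfo(dtype).min`), so the mask can
    # be added to the attention scores before the softmax. With `flash_attention_2` the 2D padding mask
    # (or `None` when nothing is padded) is returned instead.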

    @staticmethod
    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
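
    # Worked toy example (hypothetical values, for illustration only): with `sequence_length=3`,
    # `target_length=5` and `cache_position=torch.tensor([2, 3, 4])`, the mask built above lets query
    # row i attend to key positions 0..cache_position[i] (value 0.0) and blocks everything to the right
    # with the dtype minimum:
    #
    #     row 0 (position 2): [0, 0, 0, min, min]
    #     row 1 (position 3): [0, 0, 0, 0,   min]
    #     row 2 (position 4): [0, 0, 0, 0,   0  ]
    #
    # A 2D padding mask, if given, additionally re-blocks the padded key columns.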


@add_start_docstrings(
    "A conditional generation model with a language modeling head. Can be used for sequence generation tasks.",
    PIX2STRUCT_START_DOCSTRING,
)
class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMixin):
    config_class = Pix2StructConfig
    main_input_name = "flattened_patches"
    _tied_weights_keys = ["decoder.lm_head.weight"]

    def __init__(self, config: Pix2StructConfig):
        super().__init__(config)

        self.encoder = Pix2StructVisionModel(config.vision_config)
        self.decoder = Pix2StructTextModel(config.text_config)

        self.is_vqa = config.is_vqa

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.decoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.decoder.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.decoder.set_output_embeddings(new_embeddings)

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
        model_embeds = self.decoder.resize_token_embeddings(new_num_tokens)

        # update vocab size
        self.config.text_config.vocab_size = new_num_tokens

        return model_embeds

    def get_decoder(self):
        return self.decoder

    def get_encoder(self):
        return self.encoder
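
    # Illustrative sketch (hypothetical variables, not executed here): the vision encoder can be run once
    # and its output reused across several decoding calls by passing `encoder_outputs`, which skips the
    # `if encoder_outputs is None` branch in `forward` below:
    #
    #     encoder_outputs = model.get_encoder()(flattened_patches=patches, attention_mask=mask)
    #     out = model(encoder_outputs=encoder_outputs, decoder_input_ids=decoder_ids)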

    @add_start_docstrings_to_model_forward(PIX2STRUCT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        flattened_patches: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        decoder_inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        r"""
        Returns:

        Example:

        Inference:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration

        >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
        >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")

        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> # autoregressive generation
        >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> print(generated_text)
        A stop sign is on a street corner.

        >>> # conditional generation
        >>> text = "A picture of"
        >>> inputs = processor(text=text, images=image, return_tensors="pt", add_special_tokens=False)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> print(generated_text)
        A picture of a stop sign with a red stop sign
        ```

        Training:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration

        >>> processor = AutoProcessor.from_pretrained("google/pix2struct-base")
        >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")

        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "A stop sign is on the street corner."

        >>> inputs = processor(images=image, return_tensors="pt")
        >>> labels = processor(text=text, return_tensors="pt").input_ids

        >>> # forward pass
        >>> outputs = model(**inputs, labels=labels)
        >>> loss = outputs.loss
        >>> print(f"{loss.item():.5f}")
        5.94282
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                flattened_patches=flattened_patches,
                attention_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)
            decoder_attention_mask = (
                decoder_attention_mask
                if decoder_attention_mask is not None
                else decoder_input_ids.ne(self.config.pad_token_id).float()
            )
            # Always attend to the first token
            decoder_attention_mask[:, 0] = 1

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            labels=labels,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=decoder_outputs.loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
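
    # Illustrative note on training (sketch, values hypothetical): when only `labels` are passed,
    # `forward` above derives the decoder inputs via `self._shift_right(labels)`, i.e. the labels are
    # shifted one position to the right and the decoder start token (the pad token) is prepended:
    #
    #     labels            = [[ 37, 109,   1]]
    #     decoder_input_ids = [[  0,  37, 109]]   # assuming pad_token_id == 0
    #
    # so users only need to tokenize the target text and pass it as `labels`.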