- # coding=utf-8
- # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # and OPT implementations in this library. It has been modified from its
- # original forms to accommodate minor architectural differences compared
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Idefics model."""
- from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Tuple, Union
- import torch
- import torch.nn.functional as F
- import torch.utils.checkpoint
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, StaticCache
- from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import AttentionMaskConverter
- from ...modeling_outputs import ModelOutput
- from ...modeling_utils import PretrainedConfig, PreTrainedModel
- from ...pytorch_utils import ALL_LAYERNORM_LAYERS
- from ...utils import (
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- replace_return_docstrings,
- )
- from .configuration_idefics import IdeficsConfig
- from .perceiver import IdeficsPerceiverResampler
- from .vision import IdeficsVisionTransformer
- logger = logging.get_logger(__name__)
- _CONFIG_FOR_DOC = "IdeficsConfig"
- @dataclass
- class IdeficsBaseModelOutputWithPast(ModelOutput):
- """
- Base class for Idefics model's outputs that may also contain a past key/values (to speed up sequential decoding).
- Args:
- last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
- Sequence of hidden-states at the output of the last layer of the model.
- If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
- hidden_size)` is output.
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally if
- `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
- encoder_sequence_length, embed_size_per_head)`.
- Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
- `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
- input) to speed up sequential decoding.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
- Attention weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
- Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
- num_images, sequence_length, hidden_size)`.
- Image hidden-states of the model produced by the vision encoder, and optionally by the perceiver.
- """
- last_hidden_state: torch.FloatTensor = None
- past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
- image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- @dataclass
- class IdeficsCausalLMOutputWithPast(ModelOutput):
- """
- Base class for Idefics causal language model (or autoregressive) outputs.
- Args:
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
- Language modeling loss (for next-token prediction).
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
- `past_key_values` input) to speed up sequential decoding.
- hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
- one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
- Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
- sequence_length)`.
- Attention weights after the attention softmax, used to compute the weighted average in the self-attention
- heads.
- image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
- Tuple of `torch.FloatTensor` (one for the output of the image embeddings) of shape `(batch_size,
- num_images, sequence_length, hidden_size)`.
- Image hidden-states of the model produced by the vision encoder, and optionally by the perceiver.
- """
- loss: Optional[torch.FloatTensor] = None
- logits: torch.FloatTensor = None
- past_key_values: Optional[List[torch.FloatTensor]] = None
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- attentions: Optional[Tuple[torch.FloatTensor]] = None
- image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
- def expand_inputs_for_generation(
- input_ids,
- expand_size=1,
- is_encoder_decoder=False,
- attention_mask=None,
- encoder_outputs=None,
- **model_kwargs,
- ):
- expanded_return_idx = (
- torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
- )
- input_ids = input_ids.index_select(0, expanded_return_idx)
- model_kwargs["pixel_values"] = model_kwargs.get("pixel_values", None)
- model_kwargs["image_encoder_embeddings"] = model_kwargs.get("image_encoder_embeddings", None)
- model_kwargs["perceiver_embeddings"] = model_kwargs.get("perceiver_embeddings", None)
- model_kwargs["image_attention_mask"] = model_kwargs.get("image_attention_mask", None)
- if "token_type_ids" in model_kwargs:
- token_type_ids = model_kwargs["token_type_ids"]
- model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
- if attention_mask is not None:
- model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
- if model_kwargs["image_attention_mask"] is not None:
- model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
- 0, expanded_return_idx
- )
- if model_kwargs["pixel_values"] is not None:
- model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
- elif model_kwargs["image_encoder_embeddings"] is not None:
- model_kwargs["image_encoder_embeddings"] = model_kwargs["image_encoder_embeddings"].index_select(
- 0, expanded_return_idx
- )
- elif model_kwargs["perceiver_embeddings"] is not None:
- model_kwargs["perceiver_embeddings"] = model_kwargs["perceiver_embeddings"].index_select(
- 0, expanded_return_idx
- )
- return input_ids, model_kwargs
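- # Illustrative usage sketch (comment only, not executed; the shapes are hypothetical): with expand_size=2 each
- # batch row is duplicated so that text and image inputs stay aligned, e.g. for multiple return sequences:
- #   input_ids: (2, 7) -> (4, 7); pixel_values: (2, num_images, 3, H, W) -> (4, num_images, 3, H, W)
- #   input_ids, model_kwargs = expand_inputs_for_generation(
- #       input_ids, expand_size=2, attention_mask=attention_mask, pixel_values=pixel_values
- #   )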
- def freeze_model(model, module_exceptions=[]):
- mapping = {
- "LayerNorm": nn.LayerNorm,
- "Linear": nn.Linear,
- "Embedding": nn.Embedding,
- }
- module_exceptions_mapped = [mapping[m] for m in module_exceptions]
- for module in model.modules():
- if module_exceptions and any(isinstance(module, t) for t in module_exceptions_mapped):
- module.requires_grad_(True) # Explicitly setting it to True to avoid any mistakes
- else:
- module.requires_grad_(False)
- return model
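- # Illustrative usage sketch (comment only): freeze_model(model.vision_model, module_exceptions=["LayerNorm"])
- # sets requires_grad=False everywhere in the vision tower except on nn.LayerNorm modules, which stay trainable.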
- class IdeficsDecoupledEmbedding(nn.Embedding):
- # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
- """
- Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
- regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
- then it will create `num_additional_embeddings` additional parameters that are always trained. If
- `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
- """
- def __init__(
- self,
- num_embeddings,
- num_additional_embeddings,
- embedding_dim,
- partially_freeze: Optional[bool] = False,
- device=None,
- dtype=None,
- padding_idx=None,
- **kwargs,
- ) -> None:
- """
- Args:
- num_embeddings (`int`):
- Size of the dictionary of embeddings
- num_additional_embeddings (`int`):
- Number of additional embeddings. Only useful when using `partially_freeze=True`.
- embedding_dim (`int`):
- The size of each embedding vector
- partially_freeze (`bool`, *optional*, defaults to `False`):
- If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
- padding_idx (`int`, *optional*):
- The padding index (needs to be less than num_embeddings)
- Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `max_norm`
- or `norm_type`. We are not supporting these.
- """
- if padding_idx is not None and padding_idx >= num_embeddings:
- raise ValueError(f"padding_idx must be less than num_embeddings. Got {padding_idx} and {num_embeddings}")
- super().__init__(
- num_embeddings=num_embeddings,
- embedding_dim=embedding_dim,
- device=device,
- dtype=dtype,
- padding_idx=padding_idx,
- **kwargs,
- )
- self.num_embeddings = num_embeddings
- self.padding_idx = padding_idx
- self.num_additional_embeddings = num_additional_embeddings
- self.partially_freeze = partially_freeze
- if partially_freeze:
- self.weight.requires_grad_(False)
- if self.num_additional_embeddings > 0:
- self.additional_embedding = nn.Embedding(
- num_embeddings=self.num_additional_embeddings,
- embedding_dim=embedding_dim,
- device=device,
- dtype=dtype,
- )
- def forward(self, input_ids):
- """
- we have 2 embeddings, with different indices - one pretrained self.weight and another
- self.additional_embedding.weight that is being trained.
- in order to make a lookup of the input ids, we:
- 1. find out the indices of the entries belonging to the 2nd embedding
- 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
- embedding starts from 0 and not num_embeddings
- 3. perform the 2nd embedding lookup
- 4. now we handle the 1st embedding: we overwrite the indices belonging to the 2nd embedding with a padding index
- 5. perform the 1st embedding lookup
- 6. finally, we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
- note: for the 1st embedding lookup we could have looked up only the low indices and skipped the padding, but
- then we would have to create a new tensor and populate it with 2 tensors spread out across various indices,
- i.e. not a simple concat. The complex case hasn't been benchmarked, but given that seqlens are usually
- relatively short it is probably not faster (or not by much). It might still be worth measuring.
- """
- if self.num_additional_embeddings == 0:
- return F.embedding(input_ids, self.weight)
- # Clone so that we don't modify the original input_ids later on
- input_ids = input_ids.clone()
- additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
- input_ids_additional_vocab = input_ids[additional_vocab_indices]
- additional_embeddings = self.additional_embedding(input_ids_additional_vocab - self.num_embeddings)
- # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
- input_ids[additional_vocab_indices] = 0
- full_vector = F.embedding(input_ids, self.weight)
- # overwrite the records with high indices
- full_vector[additional_vocab_indices] = additional_embeddings
- return full_vector
- def extra_repr(self) -> str:
- return "num_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
- self.num_embeddings,
- self.num_additional_embeddings,
- self.embedding_dim,
- self.partially_freeze,
- )
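- # Illustrative usage sketch (comment only; the sizes are made up): a frozen 32000-token vocabulary with two
- # extra, always-trainable tokens living at indices 32000 and 32001:
- #   embed = IdeficsDecoupledEmbedding(num_embeddings=32000, num_additional_embeddings=2,
- #       embedding_dim=4096, partially_freeze=True)
- #   hidden = embed(torch.tensor([[1, 31999, 32001]]))  # ids >= num_embeddings hit additional_embedding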
- class IdeficsDecoupledLinear(nn.Linear):
- # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
- """
- Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
- regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `out_additional_features` > 0,
- then it will create `out_additional_features * in_features` additional parameters that are always trained. If
- `out_additional_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
- """
- def __init__(
- self,
- in_features: int,
- out_features: int,
- out_additional_features: int = 0,
- bias: bool = True,
- partially_freeze: bool = True,
- device=None,
- dtype=None,
- ) -> None:
- """
- out_additional_features: int. Number of additional trainable dimensions. Only makes sense when
- `partially_freeze=True`. partially_freeze: bool. If True, the regular `weight` will be frozen and extra
- parameters (if any) will be trainable. If False, default to the regular behavior of nn.Linear.
- """
- super().__init__(in_features, out_features, bias, device, dtype)
- self.out_additional_features = out_additional_features
- self.partially_freeze = partially_freeze
- self.in_features = in_features
- self.out_features = out_features
- if partially_freeze:
- self.weight.requires_grad_(False)
- if bias:
- self.bias.requires_grad_(False)
- if out_additional_features > 0:
- self.additional_fc = nn.Linear(
- in_features=in_features,
- out_features=out_additional_features,
- bias=bias,
- device=device,
- dtype=dtype,
- )
- def forward(self, input: torch.Tensor) -> torch.Tensor:
- output = F.linear(input, self.weight, self.bias)
- if self.out_additional_features > 0:
- additional_features = self.additional_fc(input)
- output = torch.cat((output, additional_features), -1)
- return output
- def extra_repr(self) -> str:
- """Overwriting `nn.Linear.extra_repr` to include new parameters."""
- return "in_features={}, out_features={}, out_additional_features={}, bias={}, partially_freeze={}".format(
- self.in_features,
- self.out_features,
- self.out_additional_features,
- self.bias is not None,
- self.partially_freeze,
- )
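- # Illustrative usage sketch (comment only; the sizes are made up): a frozen projection with two extra
- # trainable output features concatenated on the last dimension:
- #   lm_head = IdeficsDecoupledLinear(in_features=4096, out_features=32000, out_additional_features=2,
- #       bias=False, partially_freeze=True)
- #   logits = lm_head(hidden_states)  # last dim is 32000 + 2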
- # this was adapted from LlamaRMSNorm
- class IdeficsRMSNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-6):
- """
- IdeficsRMSNorm is equivalent to T5LayerNorm
- """
- super().__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.variance_epsilon = eps
- def forward(self, hidden_states):
- variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
- # convert into half-precision if necessary
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
- hidden_states = hidden_states.to(self.weight.dtype)
- return self.weight * hidden_states
- def extra_repr(self):
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
- ALL_LAYERNORM_LAYERS.append(IdeficsRMSNorm)
- # this was adapted from LlamaRotaryEmbedding
- class IdeficsEmbedding(torch.nn.Module):
- def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
- super().__init__()
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- # Build here to make `torch.jit.trace` work.
- self._set_cos_sin_cache(
- seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
- )
- def _set_cos_sin_cache(self, seq_len, device, dtype):
- self.max_seq_len_cached = seq_len
- t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
- self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
- def forward(self, x, seq_len=None):
- # x: [bs, num_attention_heads, seq_len, head_size]
- if seq_len > self.max_seq_len_cached:
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
- return (
- self.cos_cached[:seq_len].to(dtype=x.dtype),
- self.sin_cached[:seq_len].to(dtype=x.dtype),
- )
- def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
- # Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors.
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`):
- The position indices of the tokens corresponding to the query and key tensors. For example, this can be
- used to pass offsetted position ids when working with a KV-cache.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
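- # As a reading aid: rotate_half maps [x1, x2] (the two halves of the head dimension) to [-x2, x1], so
- # q * cos + rotate_half(q) * sin rotates each paired coordinate (i, i + head_dim // 2) by the position
- # angle. The pairing differs from the RoPE paper's interleaving, but the resulting attention scores are
- # equivalent.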
- # this was adapted from LlamaMLP
- class IdeficsMLP(nn.Module):
- def __init__(
- self,
- hidden_size: int,
- intermediate_size: int,
- hidden_act: str,
- ):
- super().__init__()
- self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
- self.act_fn = ACT2FN[hidden_act]
- def forward(self, x):
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
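- # As a reading aid: this is the LLaMA-style gated MLP, down_proj(act_fn(gate_proj(x)) * up_proj(x)),
- # i.e. SwiGLU when `hidden_act` is "silu".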
- # this was adapted from LlamaAttention
- class IdeficsAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(
- self,
- hidden_size: int,
- num_heads: int,
- dropout: float = 0.0,
- is_cross_attention: bool = False,
- config: PretrainedConfig = None,
- qk_layer_norms: bool = False,
- layer_idx: int = None,
- ):
- super().__init__()
- self.hidden_size = hidden_size
- self.num_heads = num_heads
- self.head_dim = hidden_size // num_heads
- self.dropout = dropout
- self.is_causal = True
- self.layer_idx = layer_idx
- if layer_idx is None:
- logger.warning_once(
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
- "when creating this class."
- )
- if (self.head_dim * num_heads) != self.hidden_size:
- raise ValueError(
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
- f" and `num_heads`: {num_heads})."
- )
- self.is_cross_attention = is_cross_attention
- if not hasattr(nn.functional, "scaled_dot_product_attention"):
- raise ValueError("this model requires pytorch 2.0 or higher")
- if self.is_cross_attention:
- kv_input_dim = (
- self.hidden_size if not hasattr(config.vision_config, "embed_dim") else config.vision_config.embed_dim
- )
- self.q_proj = nn.Linear(
- self.hidden_size,
- num_heads * self.head_dim,
- bias=False,
- )
- self.k_proj = nn.Linear(kv_input_dim, num_heads * self.head_dim, bias=False)
- self.v_proj = nn.Linear(
- kv_input_dim,
- num_heads * self.head_dim,
- bias=False,
- )
- else:
- self.q_proj = nn.Linear(
- self.hidden_size,
- num_heads * self.head_dim,
- bias=False,
- )
- self.k_proj = nn.Linear(
- self.hidden_size,
- num_heads * self.head_dim,
- bias=False,
- )
- self.v_proj = nn.Linear(
- self.hidden_size,
- num_heads * self.head_dim,
- bias=False,
- )
- self.o_proj = nn.Linear(
- num_heads * self.head_dim,
- hidden_size,
- bias=False,
- )
- self.rotary_emb = IdeficsEmbedding(self.head_dim)
- self.qk_layer_norms = qk_layer_norms
- if self.qk_layer_norms:
- self.q_layer_norm = IdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps)
- self.k_layer_norm = IdeficsRMSNorm(self.head_dim, eps=config.rms_norm_eps)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # if key_value_states are provided this layer is used as a cross-attention layer
- is_cross_attention = self.is_cross_attention or key_value_states is not None
- bsz, q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- if not is_cross_attention:
- key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- else:
- _, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len`
- key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = (
- self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
- )
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += cache_position[0]
- if not is_cross_attention:
- cos, sin = self.rotary_emb(value_states, seq_len=max(kv_seq_len, q_len))
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
- # [bsz, nh, t, hd]
- if past_key_value is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"cache_position": cache_position}
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- if self.qk_layer_norms:
- query_states = self.q_layer_norm(query_states)
- key_states = self.k_layer_norm(key_states)
- causal_mask = attention_mask
- if attention_mask is not None:
- causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == "cuda" and attention_mask is not None:
- query_states = query_states.contiguous()
- key_states = key_states.contiguous()
- value_states = value_states.contiguous()
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
- is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query_states,
- key_states,
- value_states,
- attn_mask=causal_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
- )
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
- attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
- attn_weights = None
- if output_attentions:
- logger.warning_once(
- "attn_weights are not extracted in scaled_dot_product_attention. The model returns None instead"
- )
- return attn_output, attn_weights, past_key_value
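- # Shape flow as a reading aid (self-attention case): hidden_states (bsz, q_len, hidden_size) is projected
- # to q/k/v of shape (bsz, num_heads, q_len, head_dim), RoPE is applied to q/k, SDPA returns
- # (bsz, num_heads, q_len, head_dim), and o_proj maps the merged heads back to (bsz, q_len, hidden_size).
- # In the cross-attention case, k/v instead come from `key_value_states` (the image hidden states) and
- # RoPE is skipped.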
- # this was adapted from LlamaDecoderLayer
- class IdeficsDecoderLayer(nn.Module):
- def __init__(self, config: IdeficsConfig, layer_idx: int = None):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.self_attn = IdeficsAttention(
- hidden_size=self.hidden_size,
- num_heads=config.num_attention_heads,
- dropout=config.dropout,
- config=config,
- layer_idx=layer_idx,
- )
- self.mlp = IdeficsMLP(
- hidden_size=self.hidden_size,
- intermediate_size=config.intermediate_size,
- hidden_act=config.hidden_act,
- )
- self.input_layernorm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.dropout = config.dropout
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- """
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- # Self Attention
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
- outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
- if use_cache:
- outputs += (present_key_value,)
- return outputs
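- # As a reading aid, the decoder layer is a standard pre-norm residual block:
- #   x = x + dropout(self_attn(input_layernorm(x)))
- #   x = x + dropout(mlp(post_attention_layernorm(x)))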
- class IdeficsGatedCrossAttentionLayer(nn.Module):
- def __init__(self, config: IdeficsConfig):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.cross_attn = IdeficsAttention(
- hidden_size=self.hidden_size,
- num_heads=config.num_attention_heads,
- is_cross_attention=True,
- dropout=config.dropout,
- config=config,
- qk_layer_norms=config.qk_layer_norms,
- )
- self.mlp = IdeficsMLP(
- hidden_size=self.hidden_size,
- intermediate_size=config.intermediate_size,
- hidden_act=config.hidden_act,
- )
- self.input_layernorm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.dropout = config.dropout
- self.act_cross_attn = nn.Tanh()
- self.act_dense = nn.Tanh()
- if config.alpha_initializer == "zeros":
- if config.alpha_type == "vector":
- self.alpha_cross_attn = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
- self.alpha_dense = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
- elif config.alpha_type == "float":
- self.alpha_cross_attn = nn.Parameter(torch.zeros(1))
- self.alpha_dense = nn.Parameter(torch.zeros(1))
- else:
- raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
- elif config.alpha_initializer == "ones":
- if config.alpha_type == "vector":
- self.alpha_cross_attn = nn.Parameter(torch.ones(1, 1, self.hidden_size))
- self.alpha_dense = nn.Parameter(torch.ones(1, 1, self.hidden_size))
- elif config.alpha_type == "float":
- self.alpha_cross_attn = nn.Parameter(torch.ones(1))
- self.alpha_dense = nn.Parameter(torch.ones(1))
- else:
- raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
- elif config.alpha_initializer in {"normal", "gaussian", "random"}:
- if config.alpha_type == "vector":
- self.alpha_cross_attn = nn.Parameter(
- torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.hidden_size))
- )
- self.alpha_dense = nn.Parameter(
- torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1, 1, self.hidden_size))
- )
- elif config.alpha_type == "float":
- self.alpha_cross_attn = nn.Parameter(
- torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1))
- )
- self.alpha_dense = nn.Parameter(torch.normal(mean=0.0, std=config.alphas_initializer_range, size=(1)))
- else:
- raise ValueError(f"Unknown value for `alpha_type` ({config.alpha_type})")
- else:
- raise NotImplementedError(f"Alpha initialization scheme {config.alpha_initializer} not yet implemented!")
- if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")):
- raise ValueError("Alpha parameters not initialized correctly!")
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- image_hidden_states: Optional[torch.Tensor] = None,
- image_attention_mask: Optional[torch.Tensor] = None,
- cross_attention_gate: Optional[torch.Tensor] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- image_attention_mask (`torch.FloatTensor`, *optional*): image attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- cross_attention_gate (`torch.FloatTensor`, *optional*):
- gate of size `(batch, seq_len)` used to zero-out cross-attention output for tokens attending to no images.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
- """
- if image_hidden_states is None:
- raise ValueError(
- "`image_hidden_states` is required for Idefics cross attention module which are visual features to be"
- " conditioned on."
- )
- if cross_attention_gate is None:
- raise ValueError(
- "`cross_attention_gate` is required for Idefics cross attention module to zero-out the cross-attention hidden_states attending to no images."
- )
- if past_key_value is not None:
- raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.")
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- # Cross Attention
- hidden_states, self_attn_weights, present_key_value = self.cross_attn(
- hidden_states=hidden_states,
- key_value_states=image_hidden_states,
- attention_mask=image_attention_mask,
- output_attentions=output_attentions,
- )
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # Fill in zeros for cross_attention hidden_states of tokens attending to no images
- hidden_states[cross_attention_gate == 0] = hidden_states[cross_attention_gate == 0].fill_(0)
- hidden_states = residual + self.act_cross_attn(self.alpha_cross_attn) * hidden_states
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + self.act_dense(self.alpha_dense) * hidden_states
- outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
- if use_cache:
- outputs += (present_key_value,)
- return outputs
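- # As a reading aid, the gated cross-attention block mirrors the decoder layer but (1) attends from text to
- # `image_hidden_states` and (2) scales each residual branch with a learned tanh gate whose initialization is
- # controlled by `config.alpha_initializer` / `config.alpha_type`. Tokens attending to no image (rows where
- # `cross_attention_gate` == 0) have their cross-attention output zeroed before the residual add:
- #   x = x + tanh(alpha_cross_attn) * zero_gate(dropout(cross_attn(input_layernorm(x))))
- #   x = x + tanh(alpha_dense) * dropout(mlp(post_attention_layernorm(x)))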
- LLAMA_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
- Parameters:
- config ([`IdeficsConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
- """
- @add_start_docstrings(
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
- LLAMA_START_DOCSTRING,
- )
- class IdeficsPreTrainedModel(PreTrainedModel):
- config_class = IdeficsConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["IdeficsDecoderLayer", "IdeficsGatedCrossAttentionLayer"]
- _supports_sdpa = True
- _supports_cache_class = True
- _supports_static_cache = True
- def _init_weights(self, module):
- # important: this ported version of Idefics isn't meant for training from scratch - only
- # inference and fine-tuning - so the proper init weights code has been removed - the m4 code
- # base should be used for training from scratch and it contains the correct code.
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
- LLAMA_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- [What are attention masks?](../glossary#attention-mask)
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
- `past_key_values`).
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
- past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
- `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
- `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
- Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
- If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
- don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
- `decoder_input_ids` of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
- the complete sequence length.
- """
- @add_start_docstrings(
- "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
- LLAMA_START_DOCSTRING,
- )
- class IdeficsModel(IdeficsPreTrainedModel):
- """
- Transformer decoder consisting of `config.num_hidden_layers` layers. Each layer is a [`IdeficsDecoderLayer`]
- Args:
- config: IdeficsConfig
- """
- def __init__(self, config: IdeficsConfig):
- super().__init__(config)
- self.config = config
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
- self.embed_tokens = IdeficsDecoupledEmbedding(
- num_embeddings=config.vocab_size,
- num_additional_embeddings=config.additional_vocab_size,
- embedding_dim=config.hidden_size,
- partially_freeze=config.freeze_text_layers,
- padding_idx=self.padding_idx,
- )
- self.image_size = config.vision_config.image_size
- self.vision_config = config.vision_config
- self.vision_model = IdeficsVisionTransformer(config.vision_config)
- # Perceiver Resampler
- if config.use_resampler:
- perceiver_config = config.perceiver_config
- self.perceiver_resampler = IdeficsPerceiverResampler(
- config,
- config.vision_config.embed_dim,
- perceiver_config.resampler_depth,
- perceiver_config.resampler_n_heads,
- perceiver_config.resampler_head_dim,
- perceiver_config.resampler_n_latents,
- )
- self.layers = nn.ModuleList(
- [IdeficsDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
- )
- self.cross_layer_interval = config.cross_layer_interval
- num_cross_layers = config.num_hidden_layers // self.cross_layer_interval
- self.gated_cross_attn_layers = nn.ModuleList(
- [IdeficsGatedCrossAttentionLayer(config) for _ in range(num_cross_layers)]
- )
- self.gradient_checkpointing = False
- self.norm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- # Initialize weights and apply final processing
- self.post_init()
- self.freeze_relevant_params(config)
- def freeze_relevant_params(self, config=None):
- if config is None:
- config = self.config
- if config.freeze_text_layers:
- self.freeze_text_layers(config.freeze_text_module_exceptions)
- if config.freeze_vision_layers:
- freeze_model(self.vision_model, module_exceptions=config.freeze_vision_module_exceptions)
- def freeze_text_layers(self, module_exceptions=[]):
- for module in [self.layers, self.norm]:
- freeze_model(module, module_exceptions=module_exceptions)
- def freeze_vision_layers(self, module_exceptions=[]):
- freeze_model(self.vision_model, module_exceptions=module_exceptions)
- def get_input_embeddings(self):
- return self.embed_tokens
- def set_input_embeddings(self, value):
- self.embed_tokens = value
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- image_encoder_embeddings: Optional[torch.FloatTensor] = None,
- perceiver_embeddings: Optional[torch.FloatTensor] = None,
- image_attention_mask: Optional[torch.Tensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- interpolate_pos_encoding: Optional[bool] = False,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, IdeficsBaseModelOutputWithPast]:
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
- # kept for BC (non `Cache` `past_key_values` inputs)
- return_legacy_cache = False
- if use_cache and not isinstance(past_key_values, Cache):
- return_legacy_cache = True
- if past_key_values is None:
- past_key_values = DynamicCache()
- else:
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
- logger.warning_once(
- "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
- "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
- "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
- )
- batch_size, seq_length, _ = inputs_embeds.shape
- past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
- seq_length_with_past = seq_length + past_key_values_length
- if cache_position is None:
- cache_position = torch.arange(
- past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device
- )
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- position_ids = position_ids[:, -seq_length:]
- elif position_ids is None:
- position_ids = cache_position.unsqueeze(0)
- if (pixel_values, image_encoder_embeddings, perceiver_embeddings).count(None) != 2:
- raise ValueError(
- "Exactly 1 of pixel_values, image_encoder_embeddings or perceiver_embeddings has to be not-None."
- )
- elif pixel_values is not None:
- pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility
- batch_size, num_images = pixel_values.shape[:2]
- pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
- # Get sequence from the vision encoder
- image_hidden_states = self.vision_model(
- pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
- ).last_hidden_state
- elif image_encoder_embeddings is not None:
- batch_size, num_images, image_seq_len, image_hidden_size = image_encoder_embeddings.size()
- image_hidden_states = image_encoder_embeddings.to(dtype=self.dtype, device=device)
- image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
- if self.config.use_resampler:
- if perceiver_embeddings is None:
- perceiver_embeddings = self.perceiver_resampler(image_hidden_states)
- image_seq_len, image_hidden_size = perceiver_embeddings.size(1), perceiver_embeddings.size(2)
- else:
- batch_size, num_images, image_seq_len, image_hidden_size = perceiver_embeddings.size()
- image_hidden_states = perceiver_embeddings
- elif perceiver_embeddings is None:
- image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
- else:
- raise ValueError("If `perceiver_embeddings` are passed, use_resampler should be True")
- image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
- # # Hack to use the model in full language modeling mode
- # image_attention_mask = torch.zeros(batch_size, seq_length, 1, dtype=torch.long, device=image_hidden_states.device)
- # Make image_attention_mask compatible with hidden states
- text_seq_len = image_attention_mask.size(1)
- image_attention_mask = image_attention_mask.unsqueeze(-1)
- image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
- image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
- if image_hidden_states is not None:
- image_batch_size, image_sequence_length, _ = image_hidden_states.size()
- image_hidden_shape = (image_batch_size, image_sequence_length)
- if image_attention_mask is None:
- image_attention_mask = torch.ones(image_hidden_shape, device=device)
- image_attention_mask = self.invert_attention_mask(image_attention_mask)
- else:
- image_attention_mask = None
- # cross_attention_gate:
- # For any tokens attending to no images, the hidden_states coming out of the cross-attention should be zeroed-out.
- # `image_attention_mask` has shape [bsz, 1, num_images, hidden_size] with elements equal to either 0.0 or a very negative number.
- # If any of the elements are 0.0, then the token is attending to at least one image and the gate value is 1. Otherwise the gate value is 0.
- # `cross_attention_gate` has shape [bsz, seq_len] with elements equal to either 0.0 or 1.0.
- cross_attention_gate = ((((image_attention_mask == 0.0).any(dim=-1)).to(dtype=self.dtype)).squeeze(dim=1)).to(
- device
- )
- # embed positions
- if attention_mask is None:
- attention_mask = torch.ones(
- (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
- )
- attention_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
- hidden_states = inputs_embeds
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = None
- for idx, decoder_layer in enumerate(self.layers):
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
- def vblock(
- main_block,
- hidden_states,
- attention_mask,
- position_ids,
- past_key_value,
- image_hidden_states,
- image_attention_mask,
- cross_attention_gate,
- output_attentions,
- use_cache,
- layer_idx,
- cross_layer_interval,
- gated_cross_attn_layers,
- cache_position,
- ):
- # TODO(ls): Add cross attention values to respective lists
- if layer_idx % cross_layer_interval == 0:
- xblock = gated_cross_attn_layers[layer_idx // cross_layer_interval]
- outputs = xblock(
- hidden_states,
- attention_mask=attention_mask,
- image_hidden_states=image_hidden_states,
- image_attention_mask=image_attention_mask,
- cross_attention_gate=cross_attention_gate,
- output_attentions=output_attentions,
- use_cache=use_cache,
- past_key_value=None, # not implemented
- )
- hidden_states = outputs[0]
- layer_outputs = main_block(
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- )
- return layer_outputs
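- # Interleaving sketch (illustrative, not executed here): a gated cross-attention block runs before
- # every `cross_layer_interval`-th decoder layer, and the cross-attention block index is
- # `layer_idx // cross_layer_interval`. For example, with 32 layers and an interval of 4:
- #
- #     cross_layer_interval = 4
- #     schedule = [(idx, idx // cross_layer_interval) for idx in range(32) if idx % cross_layer_interval == 0]
- #     # schedule -> [(0, 0), (4, 1), (8, 2), ..., (28, 7)]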
- if self.gradient_checkpointing and self.training:
- past_key_values = None
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
- layer_outputs = self._gradient_checkpointing_func(
- vblock,
- decoder_layer,
- hidden_states,
- attention_mask,
- position_ids,
- past_key_values,
- image_hidden_states,
- image_attention_mask,
- cross_attention_gate,
- output_attentions,
- use_cache,
- idx,
- self.cross_layer_interval,
- self.gated_cross_attn_layers,
- cache_position,
- )
- else:
- layer_outputs = vblock(
- decoder_layer,
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- image_hidden_states=image_hidden_states,
- image_attention_mask=image_attention_mask,
- cross_attention_gate=cross_attention_gate,
- output_attentions=output_attentions,
- use_cache=use_cache,
- layer_idx=idx,
- cross_layer_interval=self.cross_layer_interval,
- gated_cross_attn_layers=self.gated_cross_attn_layers,
- cache_position=cache_position,
- )
- hidden_states = layer_outputs[0]
- if use_cache:
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
- hidden_states = self.norm(hidden_states)
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
- next_cache = next_decoder_cache if use_cache else None
- if return_legacy_cache:
- next_cache = next_cache.to_legacy_cache()
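- # Legacy-cache sketch (illustrative, not executed here): when the caller passed the old tuple-of-tuples
- # cache format, the `Cache` object is converted back before being returned:
- #
- #     import torch
- #     from transformers.cache_utils import DynamicCache
- #     legacy = ((torch.randn(1, 8, 3, 64), torch.randn(1, 8, 3, 64)),)  # one layer of (key, value)
- #     cache = DynamicCache.from_legacy_cache(legacy)
- #     assert cache.to_legacy_cache()[0][0].shape == (1, 8, 3, 64)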
- image_hidden_states = image_hidden_states.view(batch_size, num_images, image_seq_len, image_hidden_size)
- if not return_dict:
- return tuple(
- v
- for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, image_hidden_states]
- if v is not None
- )
- return IdeficsBaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- image_hidden_states=image_hidden_states,
- )
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
- def _update_causal_mask(
- self,
- attention_mask: torch.Tensor,
- input_tensor: torch.Tensor,
- cache_position: torch.Tensor,
- past_key_values: Cache,
- output_attentions: bool,
- ):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and 0.0 in attention_mask:
- return attention_mask
- return None
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
- # to infer the attention mask.
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- using_static_cache = isinstance(past_key_values, StaticCache)
- # When output_attentions is True, the SDPA implementation's forward falls back to the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
- attention_mask,
- inputs_embeds=input_tensor,
- past_key_values_length=past_seen_tokens,
- is_training=self.training,
- ):
- return None
- dtype, device = input_tensor.dtype, input_tensor.device
- sequence_length = input_tensor.shape[1]
- if using_static_cache:
- target_length = past_key_values.get_max_cache_shape()
- else:
- target_length = (
- attention_mask.shape[-1]
- if isinstance(attention_mask, torch.Tensor)
- else past_seen_tokens + sequence_length + 1
- )
- # If the provided `attention_mask` is 2D, we build a 4D causal mask here.
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=target_length,
- dtype=dtype,
- device=device,
- cache_position=cache_position,
- batch_size=input_tensor.shape[0],
- )
- if (
- self.config._attn_implementation == "sdpa"
- and attention_mask is not None
- and attention_mask.device.type == "cuda"
- and not output_attentions
- ):
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
- # Details: https://github.com/pytorch/pytorch/issues/110213
- min_dtype = torch.finfo(dtype).min
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
- return causal_mask
- @staticmethod
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
- def _prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask: torch.Tensor,
- sequence_length: int,
- target_length: int,
- dtype: torch.dtype,
- device: torch.device,
- cache_position: torch.Tensor,
- batch_size: int,
- **kwargs,
- ):
- """
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
- Args:
- attention_mask (`torch.Tensor`):
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
- `(batch_size, 1, query_length, key_value_length)`.
- sequence_length (`int`):
- The sequence length being processed.
- target_length (`int`):
- The target (key/value) length: when generating with a static cache, the mask should span the full cache
- length, to account for the zero-padded part of the cache that is not filled yet.
- dtype (`torch.dtype`):
- The dtype to use for the 4D attention mask.
- device (`torch.device`):
- The device to place the 4D attention mask on.
- cache_position (`torch.Tensor`):
- Indices depicting the position of the input sequence tokens in the sequence.
- batch_size (`int`):
- Batch size.
- """
- if attention_mask is not None and attention_mask.dim() == 4:
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
- causal_mask = attention_mask
- else:
- min_dtype = torch.finfo(dtype).min
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
- )
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
- return causal_mask
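- # Minimal usage sketch of the helper above (illustrative toy values, not executed here; assumes the
- # enclosing class is `IdeficsModel`):
- #
- #     import torch
- #     mask_2d = torch.tensor([[0, 1, 1]])  # batch of 1 with left padding, key/value length 3
- #     causal = IdeficsModel._prepare_4d_causal_attention_mask_with_cache_position(
- #         attention_mask=mask_2d,
- #         sequence_length=3,
- #         target_length=3,
- #         dtype=torch.float32,
- #         device="cpu",
- #         cache_position=torch.arange(3),
- #         batch_size=1,
- #     )
- #     # causal.shape -> (1, 1, 3, 3); padded and future positions hold torch.finfo(torch.float32).min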
- class IdeficsForVisionText2Text(IdeficsPreTrainedModel, GenerationMixin):
- _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
- _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
- def __init__(self, config, vision_model=None):
- super().__init__(config)
- self.model = IdeficsModel(config)
- self.lm_head = IdeficsDecoupledLinear(
- in_features=config.hidden_size,
- out_features=config.vocab_size,
- out_additional_features=config.additional_vocab_size,
- bias=False,
- partially_freeze=config.freeze_lm_head,
- )
- # Initialize weights and apply final processing
- self.post_init()
- def get_input_embeddings(self):
- return self.model.embed_tokens
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
- def get_output_embeddings(self):
- return self.lm_head
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
- def set_decoder(self, decoder):
- self.model = decoder
- def get_decoder(self):
- return self.model
- def tie_weights(self):
- """
- Overwrite `transformers.modeling_utils.PreTrainedModel.tie_weights` to handle the case of
- IdeficsDecoupledLinear and IdeficsDecoupledEmbedding.
- """
- output_embeddings = self.get_output_embeddings()
- input_embeddings = self.get_input_embeddings()
- if getattr(self.config, "tie_word_embeddings", True):
- output_embeddings.weight = input_embeddings.weight
- if input_embeddings.num_additional_embeddings > 0:
- assert output_embeddings.out_additional_features == input_embeddings.num_additional_embeddings
- output_embeddings.additional_fc.weight = input_embeddings.additional_embedding.weight
- if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
- output_embeddings.out_features = input_embeddings.num_embeddings
- if hasattr(output_embeddings, "out_additional_features") and hasattr(
- input_embeddings, "num_additional_embeddings"
- ):
- output_embeddings.out_additional_features = input_embeddings.num_additional_embeddings
- @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=IdeficsCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- pixel_values: Optional[torch.FloatTensor] = None,
- image_encoder_embeddings: Optional[torch.FloatTensor] = None,
- perceiver_embeddings: Optional[torch.FloatTensor] = None,
- image_attention_mask: Optional[torch.Tensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- interpolate_pos_encoding: Optional[bool] = False,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, IdeficsCausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
- Returns:
- Example:
- ```python
- >>> from transformers import AutoProcessor, IdeficsForVisionText2Text
- >>> model = IdeficsForVisionText2Text.from_pretrained("HuggingFaceM4/idefics-9b")
- >>> processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics-9b")
- >>> dogs_image_url_1 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_nlvr2/raw/main/image1.jpeg"
- >>> dogs_image_url_2 = "https://huggingface.co/datasets/hf-internal-testing/fixtures_nlvr2/raw/main/image2.jpeg"
- >>> prompts = [
- ... [
- ... "User:",
- ... dogs_image_url_1,
- ... "Describe this image.\nAssistant: An image of two dogs.\n",
- ... "User:",
- ... dogs_image_url_2,
- ... "Describe this image.\nAssistant:",
- ... ]
- ... ]
- >>> inputs = processor(prompts, return_tensors="pt")
- >>> generate_ids = model.generate(**inputs, max_new_tokens=6)
- >>> processor.batch_decode(generate_ids, skip_special_tokens=True)
- ```"""
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- pixel_values=pixel_values,
- image_encoder_embeddings=image_encoder_embeddings,
- perceiver_embeddings=perceiver_embeddings,
- image_attention_mask=image_attention_mask,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- interpolate_pos_encoding=interpolate_pos_encoding,
- return_dict=return_dict,
- cache_position=cache_position,
- )
- hidden_states = outputs[0]
- logits = self.lm_head(hidden_states)
- loss = None
- if labels is not None:
- labels = labels.to(logits.device)
- # Shift so that tokens < n predict n
- if attention_mask is not None:
- # we use the input attention mask to shift the logits and labels, because it is 2D.
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
- shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to(logits.device)
- shift_logits = logits[..., :-1, :][shift_attention_mask != 0].contiguous()
- shift_labels = labels[..., 1:][shift_attention_mask != 0].contiguous()
- else:
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
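- # Toy illustration of the masked shift above (not executed here): with right padding, the padded
- # position is dropped before the loss is computed:
- #
- #     import torch
- #     logits = torch.randn(1, 4, 10)                 # [batch, seq, vocab]
- #     labels = torch.tensor([[5, 6, 7, 8]])
- #     attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
- #     shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1):]
- #     shift_logits = logits[..., :-1, :][shift_attention_mask != 0]   # -> (2, 10)
- #     shift_labels = labels[..., 1:][shift_attention_mask != 0]       # -> (2,)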
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
- return IdeficsCausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- image_hidden_states=outputs.image_hidden_states,
- )
- def prepare_inputs_for_generation(
- self,
- input_ids,
- attention_mask=None,
- position_ids=None,
- inputs_embeds=None,
- past_key_values=None,
- cache_position=None,
- pixel_values=None,
- image_hidden_states=None,
- image_attention_mask=None,
- use_cache=None,
- **kwargs,
- ):
- # Overwritten -- custom processing based on `config.use_resampler`
- model_inputs = {}
- if image_hidden_states is not None:
- if self.config.use_resampler:
- model_inputs["perceiver_embeddings"] = image_hidden_states
- else:
- model_inputs["image_encoder_embeddings"] = image_hidden_states
- else:
- model_inputs["pixel_values"] = pixel_values
- # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
- if past_key_values is not None:
- if inputs_embeds is not None:
- input_ids = input_ids[:, -cache_position.shape[0] :]
- elif input_ids.shape[1] != cache_position.shape[0]:
- input_ids = input_ids[:, cache_position]
- if image_attention_mask is not None:
- image_attention_mask = image_attention_mask[:, -input_ids.shape[1] :]
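- # Cache-position slicing sketch (illustrative, not executed here): during cached decoding only the
- # unprocessed token(s) are kept:
- #
- #     import torch
- #     input_ids = torch.tensor([[11, 12, 13, 14, 15]])
- #     cache_position = torch.tensor([4])   # only the last token is new
- #     input_ids[:, cache_position]         # -> tensor([[15]])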
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[:, -input_ids.shape[1] :]
- # This `clone` call is needed to avoid recapturing CUDA graphs with `torch.compile`'s `mode="reduce-overhead"`, as otherwise
- # the input `position_ids` would have a varying stride during decoding. Simply calling `.contiguous()` is not sufficient:
- # in the batch-size-1 case, `position_ids` is already contiguous but its stride still varies, which retriggers a capture.
- position_ids = position_ids.clone(memory_format=torch.contiguous_format)
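- # Toy example of the on-the-fly position ids above (illustrative, not executed here): left-padded
- # tokens get a dummy position of 1 while real tokens are numbered from 0:
- #
- #     import torch
- #     attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
- #     position_ids = attention_mask.long().cumsum(-1) - 1
- #     position_ids.masked_fill_(attention_mask == 0, 1)
- #     # position_ids -> tensor([[1, 1, 0, 1, 2]])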
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
- if inputs_embeds is not None and cache_position[0] == 0:
- model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
- else:
- # The clone here is for the same reason as for `position_ids`.
- model_inputs.update(
- {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
- )
- model_inputs.update(
- {
- "past_key_values": past_key_values,
- "use_cache": use_cache,
- "cache_position": cache_position,
- "position_ids": position_ids,
- "attention_mask": attention_mask,
- "image_attention_mask": image_attention_mask,
- "interpolate_pos_encoding": kwargs.get("interpolate_pos_encoding", False),
- }
- )
- return model_inputs
- def _update_model_kwargs_for_generation(
- self,
- outputs: ModelOutput,
- model_kwargs: Dict[str, Any],
- is_encoder_decoder: bool = False,
- **kwargs,
- ) -> Dict[str, Any]:
- model_kwargs = super()._update_model_kwargs_for_generation(
- outputs,
- model_kwargs,
- is_encoder_decoder,
- **kwargs,
- )
- if "image_attention_mask" in model_kwargs:
- image_attention_mask = model_kwargs["image_attention_mask"]
- last_mask = image_attention_mask[:, -1, :].unsqueeze(1)
- if model_kwargs.get("use_cache", True):
- model_kwargs["image_attention_mask"] = last_mask
- else:
- model_kwargs["image_attention_mask"] = torch.cat([image_attention_mask, last_mask], dim=1)
- # Get the precomputed image_hidden_states
- model_kwargs["image_hidden_states"] = outputs.image_hidden_states
- return model_kwargs
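- # Sketch of the mask update above (illustrative, not executed here): with a KV cache only the last
- # row of `image_attention_mask` is needed for the next decoding step:
- #
- #     import torch
- #     image_attention_mask = torch.ones(1, 7, 2, dtype=torch.long)  # [batch, text_seq, num_images]
- #     last_mask = image_attention_mask[:, -1, :].unsqueeze(1)       # [batch, 1, num_images]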
- @staticmethod
- def _reorder_cache(past, beam_idx):
- reordered_past = ()
- for layer_past in past:
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
- return reordered_past