| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646 |
- # coding=utf-8
- # Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
- # All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch GIT model."""
- import math
- from dataclasses import dataclass
- from typing import List, Optional, Tuple, Union
- import torch
- import torch.utils.checkpoint
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...file_utils import ModelOutput
- from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
- from ...modeling_outputs import (
- BaseModelOutput,
- BaseModelOutputWithPast,
- BaseModelOutputWithPooling,
- CausalLMOutputWithPast,
- )
- from ...modeling_utils import PreTrainedModel
- from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
- from ...utils import (
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- replace_return_docstrings,
- torch_int,
- )
- from .configuration_git import GitConfig, GitVisionConfig
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# Checkpoint and config identifiers; presumably consumed by the docstring helpers imported above
# (their usage is outside this excerpt -- TODO confirm).
_CHECKPOINT_FOR_DOC = "microsoft/git-base"
_CONFIG_FOR_DOC = "GitConfig"
@dataclass
# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Git
class GitVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Every field defaults to `None`.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    # Annotated as Optional because the default is None.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class GitEmbeddings(nn.Module):
    """Construct the embeddings from word and position embeddings.

    The module looks up (or receives) token embeddings, adds learned absolute position embeddings when
    `position_embedding_type == "absolute"`, and then applies LayerNorm followed by dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # position_ids has shape (1, max_position_embeddings). It is registered as a non-persistent buffer:
        # it follows the module across devices but is not written to the state dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """
        Embed the inputs and add position information.

        Args:
            input_ids: Token indices; dimension 1 is used as the sequence length. Ignored for the embedding lookup
                when `inputs_embeds` is given.
            position_ids: Explicit position indices. When `None`, the slice
                `[past_key_values_length, past_key_values_length + seq_length)` of the internal buffer is used.
            inputs_embeds: Pre-computed token embeddings used instead of looking up `input_ids`. This tensor is
                never modified in place.
            past_key_values_length: Offset of the first position when `position_ids` is not given.

        Returns:
            The embeddings after LayerNorm and dropout.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if inputs_embeds is None:
            embeddings = self.word_embeddings(input_ids)
        else:
            embeddings = inputs_embeds

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            # Out-of-place add: `embeddings` may alias the caller's `inputs_embeds`. An in-place `+=` would
            # silently overwrite the caller's tensor and raise for a leaf tensor that requires grad.
            embeddings = embeddings + position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class GitSelfAttention(nn.Module):
    """
    Multi-head self-attention with optional relative position scores and key/value caching.

    When a cache is supplied, only the keys/values after the first `image_patch_tokens` positions are stored
    (and only when `pixel_values_present` is set); the leading slice is re-attached in front of the cached
    keys/values on every call.
    """

    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        heads_divide_hidden = config.hidden_size % config.num_attention_heads == 0
        if not heads_divide_hidden and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Number of leading positions treated as image tokens: (image_size / patch_size) ** 2 + 1 per image,
        # multiplied by the number of images when `num_image_with_embedding` is set.
        vision_config = config.vision_config
        patch_tokens = int((vision_config.image_size / vision_config.patch_size) ** 2 + 1)
        if config.num_image_with_embedding is not None:
            patch_tokens *= config.num_image_with_embedding
        self.image_patch_tokens = patch_tokens

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # An explicit (truthy) argument wins over the config value; "absolute" is the fallback.
        if not position_embedding_type:
            position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.position_embedding_type = position_embedding_type

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        """Split the last dimension into `(num_attention_heads, attention_head_size)` and swap dims 1 and 2."""
        leading_dims = x.size()[:-1]
        per_head = x.view(*leading_dims, self.num_attention_heads, self.attention_head_size)
        return per_head.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: Input to the query/key/value projections.
            attention_mask: Additive mask that is summed onto the scaled attention scores.
            head_mask: Multiplicative mask applied to the attention probabilities.
            past_key_value: Cache object updated with this layer's keys/values (using `self.layer_idx`).
            output_attentions: Whether to include the attention probabilities in the result.
            pixel_values_present: Whether the first `image_patch_tokens` positions are image tokens that must
                be kept out of the cache.

        Returns:
            `(context_layer, past_key_value)` or, with `output_attentions`,
            `(context_layer, attention_probs, past_key_value)`.
        """
        # Positions before `cutoff` are never written to the cache.
        cutoff = self.image_patch_tokens if pixel_values_present else 0

        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        if past_key_value is not None:
            # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
            leading_keys, trailing_keys = key_layer[:, :, :cutoff, :], key_layer[:, :, cutoff:, :]
            leading_values, trailing_values = value_layer[:, :, :cutoff, :], value_layer[:, :, cutoff:, :]
            cached_keys, cached_values = past_key_value.update(trailing_keys, trailing_values, self.layer_idx)
            key_layer = torch.cat([leading_keys, cached_keys], dim=2)
            value_layer = torch.cat([leading_values, cached_values], dim=2)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            device = hidden_states.device
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]

            if past_key_value is not None:
                # With a cache, a single query position (the last key index) is used.
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=device).view(-1, 1)
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=device).view(1, -1)

            # Shift distances by max_position_embeddings - 1 so they index the embedding table.
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (expected to be precomputed by the caller).
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # Weighted sum of the values, then merge the head dimensions back into one of size `all_head_size`.
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*context_layer.size()[:-2], self.all_head_size)

        if output_attentions:
            return (context_layer, attention_probs, past_key_value)
        return (context_layer, past_key_value)
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
class GitSelfOutput(nn.Module):
    """Dense projection and dropout, followed by a residual connection and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
# Maps an attention-implementation name to the self-attention class that implements it.
# Only the "eager" implementation is registered here.
GIT_SELF_ATTENTION_CLASSES = {
    "eager": GitSelfAttention,
}
class GitAttention(nn.Module):
    """Self-attention followed by its output projection, with support for pruning attention heads."""

    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        # The concrete self-attention class is selected by the configured attention implementation.
        self_attention_cls = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation]
        self.self = self_attention_cls(
            config, position_embedding_type=position_embedding_type, layer_idx=layer_idx
        )
        self.output = GitSelfOutput(config)
        self.pruned_heads = set()

    # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
    def prune_heads(self, heads):
        """Remove the given attention heads from the projections and record them in `self.pruned_heads`."""
        if len(heads) == 0:
            return

        attention = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attention.num_attention_heads, attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        for projection_name in ("query", "key", "value"):
            pruned = prune_linear_layer(getattr(attention, projection_name), index)
            setattr(attention, projection_name, pruned)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        attention.num_attention_heads = attention.num_attention_heads - len(heads)
        attention.all_head_size = attention.attention_head_size * attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Apply self-attention and the output projection.

        Returns:
            A tuple whose first element is the projected attention output; the remaining elements are
            whatever the self-attention module returned after its first element.
        """
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            past_key_value,
            output_attentions,
            pixel_values_present,
        )
        # The residual connection uses the layer input, not the raw attention context.
        attention_output = self.output(self_outputs[0], hidden_states)
        return (attention_output, *self_outputs[1:])
# Copied from transformers.models.bert.modeling_bert.BertIntermediate
class GitIntermediate(nn.Module):
    """Linear expansion from `hidden_size` to `intermediate_size` followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key of ACT2FN or used directly as the activation.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `intermediate_act_fn(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
# Copied from transformers.models.bert.modeling_bert.BertOutput
class GitOutput(nn.Module):
    """Project from `intermediate_size` back to `hidden_size`, then dropout, residual add and LayerNorm."""

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class GitLayer(nn.Module):
    """One transformer layer: attention followed by a (optionally chunked) feed-forward block."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Dimension along which the feed-forward computation is chunked.
        self.seq_len_dim = 1
        self.attention = GitAttention(config, layer_idx=layer_idx)
        self.intermediate = GitIntermediate(config)
        self.output = GitOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """
        Run attention and the feed-forward block.

        Returns:
            `(layer_output, *attention_extras, present_key_value)`, where `attention_extras` are the elements
            the attention module returned between its first and last element, and `present_key_value` is the
            last element returned by the attention module.
        """
        attention_results = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=past_key_value,
            pixel_values_present=pixel_values_present,
        )
        attention_output = attention_results[0]
        # Everything between the attention output and the trailing cache entry is passed through unchanged.
        attention_extras = attention_results[1:-1]
        present_key_value = attention_results[-1]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return (layer_output, *attention_extras, present_key_value)

    def feed_forward_chunk(self, attention_output):
        """Apply the intermediate and output sub-layers to (a chunk of) the attention output."""
        return self.output(self.intermediate(attention_output), attention_output)
class GitEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `GitLayer` modules with key/value cache handling."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([GitLayer(config, i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        pixel_values_present: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        """
        Run all layers in order.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Mask forwarded unchanged to every layer.
            head_mask: Indexed per layer (`head_mask[i]`) when given.
            past_key_values: A `Cache` instance, a legacy tuple-of-tuples cache, or `None`.
            use_cache: Whether to return the cache. Forced to `False` under gradient checkpointing in training.
            output_attentions: Whether to collect each layer's attention probabilities.
            output_hidden_states: Whether to collect the hidden states before each layer and after the last.
            pixel_values_present: Forwarded to every layer, in both the regular and the checkpointed path.
            return_dict: Return a `BaseModelOutputWithPast` instead of a tuple of the non-`None` values.

        Returns:
            `BaseModelOutputWithPast`, or a tuple of the non-`None` entries among
            `(hidden_states, next_cache, all_hidden_states, all_self_attentions)`.
        """
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        next_decoder_cache = None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                # Pass the same positional arguments as the regular path below. `pixel_values_present` was
                # previously omitted here, so checkpointed layers always saw the default `False`.
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    past_key_values,
                    output_attentions,
                    pixel_values_present,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    past_key_values,
                    output_attentions,
                    pixel_values_present,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The cache is the last element of every layer's output.
                next_decoder_cache = layer_outputs[-1]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            # The caller did not pass a `Cache`, so hand back the legacy tuple format.
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class GitPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = GitConfig
    base_model_prefix = "git"
    supports_gradient_checkpointing = True
    _supports_cache_class = True
    _supports_quantized_cache = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, GitVisionEmbeddings):
            vision_std = self.config.initializer_range
            nn.init.normal_(module.class_embedding, mean=0.0, std=vision_std)
            nn.init.normal_(module.patch_embedding.weight, std=vision_std)
            nn.init.normal_(module.position_embedding.weight, std=vision_std)

        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            padding_idx = module.padding_idx
            if padding_idx is not None:
                # Keep the padding row at zero.
                module.weight.data[padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
- GIT_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
- and behavior.
- Parameters:
- config ([`GitConfig`]): Model configuration class with all the parameters of the model.
- Initializing with a config file does not load the weights associated with the model, only the
- configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
- """
- GIT_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `({0})`):
- Indices of input sequence tokens in the vocabulary.
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- [What are attention masks?](../glossary#attention-mask)
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
- Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
- config.max_position_embeddings - 1]`.
- [What are position IDs?](../glossary#position-ids)
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
- Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
- [`CLIPImageProcessor.__call__`] for details.
- head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
- Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
- Two formats are allowed:
- - a [`~cache_utils.Cache`] instance, see our
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- - Tuple of `tuple(torch.FloatTensor)` of length `config.num_hidden_layers`, with each tuple having 2 tensors of
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
- cache format.
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
- legacy cache format will be returned.
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
- of shape `(batch_size, sequence_length)`.
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
- Whether to interpolate the pre-trained position encodings.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
class GitVisionEmbeddings(nn.Module):
    """Patch, class-token and position embeddings for the vision encoder."""

    def __init__(self, config: GitVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Non-overlapping patches: kernel size and stride both equal the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        # The class-token position is kept as is; only the patch positions are resampled.
        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]
        target_grid = (height // self.patch_size, width // self.patch_size)

        # Lay the patch positions out as a square grid, channels first, for 2D interpolation.
        grid_side = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, grid_side, grid_side, dim).permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=target_grid,
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        """
        Embed `pixel_values` into a sequence made of the class token followed by the patch embeddings.

        Raises:
            ValueError: If `interpolate_pos_encoding` is false and the input height or width differs from
                `self.image_size`.
        """
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )

        # Cast the input to the dtype of the convolution weights before embedding.
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        if interpolate_pos_encoding:
            position_embeds = self.interpolate_pos_encoding(embeddings, height, width)
        else:
            position_embeds = self.position_embedding(self.position_ids)
        return embeddings + position_embeds
# Adapted from transformers.models.clip.modeling_clip.CLIPMLP
# (marked "Adapted" because the body is no longer a line-for-line copy of the CLIP implementation)
class GitVisionMLP(nn.Module):
    """Two-layer feed-forward block: `hidden_size -> intermediate_size -> hidden_size` with an activation between."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation is looked up by name from the config (`config.hidden_act`).
        self.activation_fn = ACT2FN[config.hidden_act]
        model_width = config.hidden_size
        inner_width = config.intermediate_size
        self.fc1 = nn.Linear(model_width, inner_width)
        self.fc2 = nn.Linear(inner_width, model_width)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1`, the activation and `fc2` in sequence; the last dimension is `hidden_size` in and out."""
        expanded = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        return self.fc2(activated)
# Adapted from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->GitVision
# (marked "Adapted" because the body is no longer a line-for-line copy of the CLIP implementation)
class GitVisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        width = config.hidden_size
        self.embed_dim = width
        self.num_heads = config.num_attention_heads
        self.head_dim, leftover = divmod(width, self.num_heads)
        if leftover != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # 1 / sqrt(head_dim); applied to the queries before the dot product.
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Creation order (k, v, q, out) is kept so seeded initialisation is unchanged.
        self.k_proj = nn.Linear(width, width)
        self.v_proj = nn.Linear(width, width)
        self.q_proj = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, width)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Split the channel axis into heads: `(bsz, seq_len, embed_dim)` -> `(bsz, num_heads, seq_len, head_dim)`."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Self-attention over `hidden_states`. Input shape: Batch x Time x Channel.

        Args:
            hidden_states: tensor of shape `(bsz, tgt_len, embed_dim)`.
            attention_mask: optional additive mask of shape `(bsz, 1, tgt_len, src_len)`.
            causal_attention_mask: optional additive mask of the same shape, added before `attention_mask`.
            output_attentions: when truthy, also return the post-softmax (pre-dropout) attention weights.

        Returns:
            `(attn_output, attn_weights)` where `attn_output` has shape `(bsz, tgt_len, embed_dim)` and
            `attn_weights` has shape `(bsz, num_heads, tgt_len, src_len)` or is `None`.

        Raises:
            ValueError: if a mask or an intermediate tensor does not have the expected shape.
        """
        bsz, tgt_len, embed_dim = hidden_states.size()
        # Heads are folded into the batch axis so the score and context products are plain `bmm` calls.
        folded_shape = (bsz * self.num_heads, -1, self.head_dim)

        scaled_queries = self.q_proj(hidden_states) * self.scale
        queries = self._shape(scaled_queries, tgt_len, bsz).view(*folded_shape)
        keys = self._shape(self.k_proj(hidden_states), -1, bsz).view(*folded_shape)
        values = self._shape(self.v_proj(hidden_states), -1, bsz).view(*folded_shape)

        src_len = keys.size(1)
        folded_scores_shape = (bsz * self.num_heads, tgt_len, src_len)
        per_head_scores_shape = (bsz, self.num_heads, tgt_len, src_len)
        expected_mask_shape = (bsz, 1, tgt_len, src_len)

        attn_weights = torch.bmm(queries, keys.transpose(1, 2))
        if attn_weights.size() != folded_scores_shape:
            raise ValueError(
                f"Attention weights should be of size {folded_scores_shape}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != expected_mask_shape:
                raise ValueError(
                    f"Attention mask should be of size {expected_mask_shape}, but is"
                    f" {causal_attention_mask.size()}"
                )
            # Unfold heads so the size-1 head axis of the mask broadcasts, then fold again.
            attn_weights = (attn_weights.view(per_head_scores_shape) + causal_attention_mask).view(
                folded_scores_shape
            )

        if attention_mask is not None:
            if attention_mask.size() != expected_mask_shape:
                raise ValueError(
                    f"Attention mask should be of size {expected_mask_shape}, but is {attention_mask.size()}"
                )
            attn_weights = (attn_weights.view(per_head_scores_shape) + attention_mask).view(folded_scores_shape)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_weights_reshaped = None
        if output_attentions:
            # The returned weights must stay in the autograd graph, so the per-head view is created first
            # and the folded tensor used below is derived from it.
            attn_weights_reshaped = attn_weights.view(per_head_scores_shape)
            attn_weights = attn_weights_reshaped.view(folded_scores_shape)

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        context = torch.bmm(attn_probs, values)
        if context.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {context.size()}"
            )

        # (bsz * heads, tgt, head_dim) -> (bsz, tgt, heads * head_dim)
        context = context.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2)
        context = context.reshape(bsz, tgt_len, embed_dim)

        return self.out_proj(context), attn_weights_reshaped
# Adapted from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision
# (marked "Adapted" because the body is no longer a line-for-line copy of the AltCLIP implementation)
class GitVisionEncoderLayer(nn.Module):
    """Pre-norm transformer layer: self-attention then MLP, each preceded by a LayerNorm and wrapped in a residual."""

    def __init__(self, config: GitVisionConfig):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_eps
        self.embed_dim = width
        self.self_attn = GitVisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(width, eps=norm_eps)
        self.mlp = GitVisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(width, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor`): second additive mask of the same size, passed through
                to the attention module unchanged.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        # Attention sub-block: normalise, attend, add back onto the un-normalised input.
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with the same pre-norm / residual pattern.
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
# Adapted from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->GitVision, CLIPConfig
# (marked "Adapted" because the body is no longer a line-for-line copy of the AltCLIP implementation)
class GitVisionEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`GitVisionEncoderLayer`] modules applied one after another.

    Args:
        config: GitVisionConfig
    """

    def __init__(self, config: GitVisionConfig):
        super().__init__()
        self.config = config
        layer_count = config.num_hidden_layers
        self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(layer_count)])
        # Off by default; the checkpointed path in `forward` is only taken when this is True while training.
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Run `inputs_embeds` through every layer in order.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Already-embedded inputs; used as the hidden states fed to the first layer.
            attention_mask (`torch.Tensor`, *optional*):
                Forwarded unchanged to every layer. [`GitVisionAttention`] adds it to the attention scores and
                requires shape `(batch_size, 1, tgt_len, src_len)`.
            causal_attention_mask (`torch.Tensor`, *optional*):
                Forwarded unchanged to every layer; same shape requirement as `attention_mask`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. Falls back to
                `config.output_attentions` when `None`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. Falls back to
                `config.output_hidden_states` when `None`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Falls back to
                `config.use_return_dict` when `None`.

        Returns:
            A [`BaseModelOutput`], or a tuple of the same fields with the `None` entries dropped.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None
        use_checkpointing = self.gradient_checkpointing and self.training

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            # Record the input of each layer; the output of the final layer is appended after the loop.
            if output_hidden_states:
                collected_states += (hidden_states,)

            if use_checkpointing:
                # `_gradient_checkpointing_func` is not defined in this class; it is expected to be attached
                # from outside when checkpointing is switched on.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                collected_attentions += (layer_outputs[1],)

        if output_hidden_states:
            collected_states += (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states, hidden_states=collected_states, attentions=collected_attentions
            )
        return tuple(
            entry for entry in (hidden_states, collected_states, collected_attentions) if entry is not None
        )
# Argument documentation shared by the vision forward passes in this file. It is a runtime string: it is handed to
# `add_start_docstrings_to_model_forward` on `GitVisionTransformer.forward` and `GitVisionModel.forward`, so its
# text must not be edited casually.
GIT_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class GitVisionTransformer(nn.Module):
    """
    Vision backbone: patch/class-token embeddings -> LayerNorm -> encoder stack -> LayerNorm.

    The attribute name `pre_layrnorm` (sic) is kept as-is: submodule names become state-dict keys, so renaming
    it would change the parameter names this module saves and loads.
    """

    # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPEncoder->GitVisionEncoder, AltCLIP->Git
    def __init__(self, config: GitVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = GitVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = GitVisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    # NOTE: the docstring below is parsed by `replace_return_docstrings`, which needs the bare `Returns:` line.
    @add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Returns:

        """
        # Per-call flags win; `None` means "use the value stored on the config".
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        encoder_outputs = self.encoder(
            inputs_embeds=self.pre_layrnorm(embedded),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Only the final hidden state is normalised; per-layer hidden states are returned as the encoder made them.
        last_hidden_state = self.post_layernorm(encoder_outputs[0])

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=last_hidden_state,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state,) + encoder_outputs[1:]
@add_start_docstrings(
    """The vision model from CLIP, used in GIT, without any head or projection on top.""",
    GIT_START_DOCSTRING,
)
class GitVisionModel(GitPreTrainedModel):
    # Thin wrapper that exposes `GitVisionTransformer` under `self.vision_model`.
    # (No class docstring on purpose: `add_start_docstrings` builds the class documentation.)
    config_class = GitVisionConfig
    main_input_name = "pixel_values"

    # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP->Git
    def __init__(self, config: GitVisionConfig):
        super().__init__(config)
        self.vision_model = GitVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding convolution, the first layer applied to `pixel_values`."""
        embeddings = self.vision_model.embeddings
        return embeddings.patch_embedding

    # NOTE: the docstring below is parsed by `replace_return_docstrings` (needs the bare `Returns:` line) and its
    # example is written as a doctest, so it is kept verbatim.
    @add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, GitVisionModel

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
        >>> model = GitVisionModel.from_pretrained("microsoft/git-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Everything else is delegated; the other `None` flags are resolved by the wrapped transformer.
        vision_kwargs = {
            "pixel_values": pixel_values,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "interpolate_pos_encoding": interpolate_pos_encoding,
            "return_dict": return_dict,
        }
        return self.vision_model(**vision_kwargs)
class GitProjection(nn.Module):
    """Map vision-encoder features to the text model width: a linear layer followed by a LayerNorm."""

    def __init__(self, config: GitConfig):
        super().__init__()
        self.config = config
        vision_cfg = config.vision_config
        text_width = config.hidden_size

        resize = nn.Linear(vision_cfg.hidden_size, text_width)
        # The LayerNorm epsilon is read from the vision config, not from the text config.
        normalize = nn.LayerNorm(text_width, eps=vision_cfg.layer_norm_eps)
        # Kept as an `nn.Sequential` so parameter names remain `visual_projection.0.*` / `visual_projection.1.*`.
        self.visual_projection = nn.Sequential(resize, normalize)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Project `embeddings` (last dim `vision_config.hidden_size`) to last dim `config.hidden_size`."""
        projected = self.visual_projection(embeddings)
        return projected
@add_start_docstrings(
    "The bare GIT Model transformer consisting of a CLIP image encoder and text decoder outputting raw hidden-states"
    " without any specific head on top.",
    GIT_START_DOCSTRING,
)
class GitModel(GitPreTrainedModel):
    # Combines the image encoder, a projection of its features to the text width, the text embeddings and the
    # transformer stack. Projected image tokens are prepended to the text tokens before the transformer runs.
    # (No class docstring on purpose: `add_start_docstrings` builds the class documentation.)

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = GitEmbeddings(config)
        self.image_encoder = GitVisionModel(config.vision_config)
        self.encoder = GitEncoder(config)

        self.visual_projection = GitProjection(config)

        # One learned offset per frame position, only created for configs that declare a frame count.
        # The misspelt attribute name is kept because it is part of the parameter names in the state dict.
        if config.num_image_with_embedding is not None:
            self.img_temperal_embedding = nn.ParameterList(
                nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
                for _ in range(config.num_image_with_embedding)
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word-embedding module of the text embeddings."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding module of the text embeddings with `value`."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        """
        Build a `(size, size)` additive causal mask: `0` on and below the diagonal, `-inf` strictly above it, so
        position `i` cannot attend to any position `j > i`.
        """
        # Default mask is for forward direction. Flip for backward direction.
        mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask

    def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
        """
        Assemble the additive attention mask for the concatenated `[image tokens, text tokens]` sequence.

        With `M = memory.shape[1]` image tokens, `T = tgt.shape[1]` text tokens and `P = past_key_values_length`,
        the mask is built from four quadrants:

        - image -> image: zeros (image tokens see each other),
        - image -> text: `-inf` (image tokens never see text),
        - text -> image: zeros (text tokens see every image token),
        - text -> text: `tgt_mask`, or all zeros of width `T + P` when `P > 0`.

        Args:
            tgt: text embeddings; `shape[1]`, `dtype` and `device` are read.
            memory: projected image features; `shape[0]` and `shape[1]` are read.
            tgt_mask: `(T, T)` additive mask for the text tokens; ignored (rebuilt as zeros) when `P > 0`.
            past_key_values_length: number of cached positions `P`.
            memory_key_padding_mask: optional boolean tensor of shape `(batch, M)` where `True` marks padded
                image positions; defaults to "nothing is padded".

        Returns:
            Tensor of shape `(batch, 1, M + T, M + P + T)` where `batch = memory.shape[0]` (or the first
            dimension of `memory_key_padding_mask` when it is given).

        Raises:
            ValueError: if `memory_key_padding_mask` is not a boolean tensor.
        """
        num_tgt = tgt.shape[1]
        num_memory = memory.shape[1]
        device = tgt.device
        dtype = tgt.dtype
        top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
        top_right = torch.full(
            (num_memory, num_tgt + past_key_values_length),
            float("-inf"),
            device=tgt.device,
            dtype=dtype,
        )
        bottom_left = torch.zeros(
            (num_tgt, num_memory),
            dtype=dtype,
            device=tgt_mask.device,
        )

        # With cached positions the text block is all zeros: the new tokens may attend to every text column.
        if past_key_values_length > 0:
            tgt_mask = torch.zeros(
                (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
                dtype=dtype,
                device=tgt_mask.device,
            )

        left = torch.cat((top_left, bottom_left), dim=0)
        right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)

        full_attention_mask = torch.cat((left, right), dim=1)[None, :]

        if memory_key_padding_mask is None:
            memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
        # if it is False, it means valid. That is, it is not a padding
        if memory_key_padding_mask.dtype != torch.bool:
            raise ValueError("Memory key padding mask must be a boolean tensor.")
        zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
        zero_negative_infinity[memory_key_padding_mask] = float("-inf")
        full_attention_mask = full_attention_mask.expand(
            (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
        )
        # `expand` returns a view that shares storage across the batch; clone before writing per-sample values.
        full_attention_mask = full_attention_mask.clone()
        origin_left = full_attention_mask[:, :, :num_memory]
        update = zero_negative_infinity[:, None, :]
        full_attention_mask[:, :, :num_memory] = origin_left + update

        # add axis for multi-head
        full_attention_mask = full_attention_mask[:, None, :, :]

        return full_attention_mask

    # NOTE: the docstring below is parsed by `replace_return_docstrings` (needs the bare `Returns:` line) and its
    # example is written as a doctest. `output_type` and the return annotation name `BaseModelOutputWithPast`
    # because that is the class this method actually constructs.
    @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=BaseModelOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        r"""
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModel
        >>> import requests
        >>> from PIL import Image

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
        >>> model = AutoModel.from_pretrained("microsoft/git-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = "this is an image of two cats"

        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_shape[1]

        # past_key_values_length: read from the `Cache` object, or from dim 2 of the first cached tensor
        # when a legacy list/tuple cache is passed.
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = (
                past_key_values[0][0].shape[2]
                if not isinstance(past_key_values, Cache)
                else past_key_values.get_seq_length()
            )

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        projected_visual_features = None
        if pixel_values is not None:
            if pixel_values.ndim == 4:
                # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
                visual_features = self.image_encoder(
                    pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
                ).last_hidden_state

            elif pixel_values.ndim == 5:
                # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
                num_frames = pixel_values.shape[1]
                # The per-frame embeddings only exist when `config.num_image_with_embedding` was set; fail with
                # a clear message instead of an AttributeError / IndexError from inside the loop.
                temporal_embeddings = getattr(self, "img_temperal_embedding", None)
                num_available = 0 if temporal_embeddings is None else len(temporal_embeddings)
                if num_frames > num_available:
                    raise ValueError(
                        f"pixel_values contains {num_frames} frames but the model only has {num_available} temporal"
                        " embeddings (see `config.num_image_with_embedding`)."
                    )

                visual_features = []
                for frame_idx in range(num_frames):
                    visual_features_frame = self.image_encoder(
                        pixel_values[:, frame_idx, :, :], interpolate_pos_encoding=interpolate_pos_encoding
                    ).last_hidden_state
                    visual_features_frame += temporal_embeddings[frame_idx]
                    visual_features.append(visual_features_frame)

                # finally, concatenate all features along sequence dimension
                visual_features = torch.cat(visual_features, dim=1)

            else:
                raise ValueError("pixel_values must be of rank 4 or 5")

            projected_visual_features = self.visual_projection(visual_features)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        # No image given: use an empty (zero-length) image sequence so the code below is uniform.
        if projected_visual_features is None:
            projected_visual_features = torch.zeros(
                (embedding_output.shape[0], 0, embedding_output.shape[2]),
                dtype=embedding_output.dtype,
                device=embedding_output.device,
            )

        # Repeat visual features to match embedding batch size.
        projected_visual_features = projected_visual_features.repeat(
            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
        )

        # concatenate patch token and text token embeddings
        hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)

        # By default, an additive causal mask is created
        # for masking the future (one direction).
        tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)

        # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
        combined_attention_mask = self.create_attention_mask(
            tgt=embedding_output,
            memory=projected_visual_features,
            tgt_mask=tgt_mask,
            past_key_values_length=past_key_values_length,
        )

        if attention_mask is not None:
            # if the user provides an attention mask, we add it to the default one
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _prepare_4d_attention_mask(
                attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
            ).to(embedding_output.device)
            if past_key_values_length > 0:
                # NOTE(review): the sliced mask is not added to `combined_attention_mask` in this branch, so a
                # user-supplied `attention_mask` has no effect while decoding with a cache — confirm intended.
                expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
            else:
                # Only the text-to-text corner of the combined mask is affected by the user mask.
                combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=combined_attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            pixel_values_present=pixel_values is not None,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutputWithPast(
            last_hidden_state=sequence_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
- @add_start_docstrings(
- """GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING
- )
- class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
- _tied_weights_keys = ["output.weight"]
- def __init__(self, config):
- super().__init__(config)
- self.git = GitModel(config)
- self.output = nn.Linear(config.hidden_size, config.vocab_size)
- # Initialize weights and apply final processing
- self.post_init()
- def get_output_embeddings(self):
- return self.output
- def set_output_embeddings(self, new_embeddings):
- self.output = new_embeddings
- @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
- pixel_values: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.Tensor]]] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- interpolate_pos_encoding: bool = False,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
- `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
- ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- Returns:
- Examples:
- Image captioning example:
- ```python
- >>> from transformers import AutoProcessor, AutoModelForCausalLM
- >>> import requests
- >>> from PIL import Image
- >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
- >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
- >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
- >>> image = Image.open(requests.get(url, stream=True).raw)
- >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
- >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
- >>> generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
- >>> print(generated_caption)
- two cats sleeping on a pink blanket next to remotes.
- ```
- Visual question answering (VQA) example:
- ```python
- >>> from transformers import AutoProcessor, AutoModelForCausalLM
- >>> from huggingface_hub import hf_hub_download
- >>> from PIL import Image
- >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
- >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")
- >>> file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
- >>> image = Image.open(file_path).convert("RGB")
- >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values
- >>> question = "what does the front of the bus say at the top?"
- >>> input_ids = processor(text=question, add_special_tokens=False).input_ids
- >>> input_ids = [processor.tokenizer.cls_token_id] + input_ids
- >>> input_ids = torch.tensor(input_ids).unsqueeze(0)
- >>> generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
- >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True))
- ['what does the front of the bus say at the top? special']
- ```
- Video captioning example:
- ```python
- >>> import av
- >>> import numpy as np
- >>> from PIL import Image
- >>> from huggingface_hub import hf_hub_download
- >>> from transformers import AutoProcessor, AutoModelForCausalLM
- >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
- >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")
- >>> # set seed for reproducability
- >>> np.random.seed(45)
- >>> def read_video_pyav(container, indices):
- ... '''
- ... Decode the video with PyAV decoder.
- ... Args:
- ... container (`av.container.input.InputContainer`): PyAV container.
- ... indices (`List[int]`): List of frame indices to decode.
- ... Returns:
- ... result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
- ... '''
- ... frames = []
- ... container.seek(0)
- ... start_index = indices[0]
- ... end_index = indices[-1]
- ... for i, frame in enumerate(container.decode(video=0)):
- ... if i > end_index:
- ... break
- ... if i >= start_index and i in indices:
- ... frames.append(frame)
- ... return np.stack([x.to_ndarray(format="rgb24") for x in frames])
- >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
- ... '''
- ... Sample a given number of frame indices from the video.
- ... Args:
- ... clip_len (`int`): Total number of frames to sample.
- ... frame_sample_rate (`int`): Sample every n-th frame.
- ... seg_len (`int`): Maximum allowed index of sample's last frame.
- ... Returns:
- ... indices (`List[int]`): List of sampled frame indices
- ... '''
- ... converted_len = int(clip_len * frame_sample_rate)
- ... end_idx = np.random.randint(converted_len, seg_len)
- ... start_idx = end_idx - converted_len
- ... indices = np.linspace(start_idx, end_idx, num=clip_len)
- ... indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
- ... return indices
- >>> # load video
- >>> file_path = hf_hub_download(
- ... repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
- ... )
- >>> container = av.open(file_path)
- >>> # sample frames
- >>> num_frames = model.config.num_image_with_embedding
- >>> indices = sample_frame_indices(
- ... clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
- ... )
- >>> frames = read_video_pyav(container, indices)
- >>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values
- >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
- >>> print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True))
- Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.']
- ```
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if labels is not None:
- use_cache = False
- outputs = self.git(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- pixel_values=pixel_values,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- past_key_values=past_key_values,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- interpolate_pos_encoding=interpolate_pos_encoding,
- return_dict=return_dict,
- )
- sequence_output = outputs[0]
- logits = self.output(sequence_output)
- loss = None
- if labels is not None:
- # we are doing next-token prediction; shift prediction scores and input ids by one
- num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
- shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
- labels = labels[:, 1:].contiguous()
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
- if not return_dict:
- output = (logits,) + outputs[1:]
- return ((loss,) + output) if loss is not None else output
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
):
    """
    Assemble the keyword arguments for one generation step.

    Overwritten -- `git` has special cache handling and doesn't support generating from `inputs_embeds` atm.

    Args:
        input_ids: Token ids; indexed as `[:, start:]` and queried for `.shape[1]`, so expected to be 2-D.
        past_key_values: Optional cache object exposing `get_seq_length()`. When given, the ids already
            covered by the cache are dropped from `input_ids`.
        attention_mask: Optional mask. When `None`, an all-ones mask with the shape of the (possibly
            trimmed) `input_ids` is created.
        use_cache: Forwarded unchanged.
        **kwargs: Only `pixel_values` is read (defaults to `None` when absent).

    Returns:
        `dict` with keys `input_ids`, `attention_mask`, `pixel_values`, `past_key_values`, `use_cache`.
    """
    if past_key_values is not None:
        cached_len = past_key_values.get_seq_length()
        total_len = input_ids.shape[1]
        # Some generation methods already pass only the last input ID; in that case
        # (nothing longer than the cache) fall back to the old behavior and keep only the final ID.
        start = cached_len if total_len > cached_len else total_len - 1
        input_ids = input_ids[:, start:]

    # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)

    model_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "pixel_values": kwargs.get("pixel_values"),
        "past_key_values": past_key_values,
        "use_cache": use_cache,
    }
    return model_inputs
- def _reorder_cache(self, past_key_values, beam_idx):
- reordered_past = ()
- for layer_past in past_key_values:
- reordered_past += (
- tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
- )
- return reordered_past
|