modeling_git.py 72 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511
66116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646
  1. # coding=utf-8
  2. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team.
  3. # All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch GIT model."""
  17. import math
  18. from dataclasses import dataclass
  19. from typing import List, Optional, Tuple, Union
  20. import torch
  21. import torch.utils.checkpoint
  22. from torch import nn
  23. from torch.nn import CrossEntropyLoss
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache
  26. from ...file_utils import ModelOutput
  27. from ...generation import GenerationMixin
  28. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  29. from ...modeling_outputs import (
  30. BaseModelOutput,
  31. BaseModelOutputWithPast,
  32. BaseModelOutputWithPooling,
  33. CausalLMOutputWithPast,
  34. )
  35. from ...modeling_utils import PreTrainedModel
  36. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  37. from ...utils import (
  38. add_start_docstrings,
  39. add_start_docstrings_to_model_forward,
  40. logging,
  41. replace_return_docstrings,
  42. torch_int,
  43. )
  44. from .configuration_git import GitConfig, GitVisionConfig
  45. logger = logging.get_logger(__name__)
  46. _CHECKPOINT_FOR_DOC = "microsoft/git-base"
  47. _CONFIG_FOR_DOC = "GitConfig"
  48. @dataclass
  49. # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Git
  50. class GitVisionModelOutput(ModelOutput):
  51. """
  52. Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
  53. Args:
  54. image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
  55. The image embeddings obtained by applying the projection layer to the pooler_output.
  56. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  57. Sequence of hidden-states at the output of the last layer of the model.
  58. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  59. Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  60. one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
  61. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
  62. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  63. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  64. sequence_length)`.
  65. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  66. heads.
  67. """
  68. image_embeds: Optional[torch.FloatTensor] = None
  69. last_hidden_state: torch.FloatTensor = None
  70. hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
  71. attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
  72. class GitEmbeddings(nn.Module):
  73. """Construct the embeddings from word and position embeddings."""
  74. def __init__(self, config):
  75. super().__init__()
  76. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  77. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  78. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  79. # any TensorFlow checkpoint file
  80. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  81. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  82. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  83. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  84. self.register_buffer(
  85. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  86. )
  87. def forward(
  88. self,
  89. input_ids: Optional[torch.LongTensor] = None,
  90. position_ids: Optional[torch.LongTensor] = None,
  91. inputs_embeds: Optional[torch.FloatTensor] = None,
  92. past_key_values_length: int = 0,
  93. ) -> torch.Tensor:
  94. if input_ids is not None:
  95. input_shape = input_ids.size()
  96. else:
  97. input_shape = inputs_embeds.size()[:-1]
  98. seq_length = input_shape[1]
  99. if position_ids is None:
  100. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  101. if inputs_embeds is None:
  102. embeddings = self.word_embeddings(input_ids)
  103. else:
  104. embeddings = inputs_embeds
  105. if self.position_embedding_type == "absolute":
  106. position_embeddings = self.position_embeddings(position_ids)
  107. embeddings += position_embeddings
  108. embeddings = self.LayerNorm(embeddings)
  109. embeddings = self.dropout(embeddings)
  110. return embeddings
  111. class GitSelfAttention(nn.Module):
  112. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  113. super().__init__()
  114. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  115. raise ValueError(
  116. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  117. f"heads ({config.num_attention_heads})"
  118. )
  119. self.layer_idx = layer_idx
  120. if layer_idx is None:
  121. logger.warning_once(
  122. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  123. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  124. "when creating this class."
  125. )
  126. self.num_attention_heads = config.num_attention_heads
  127. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  128. self.all_head_size = self.num_attention_heads * self.attention_head_size
  129. self.image_patch_tokens = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
  130. if config.num_image_with_embedding is not None:
  131. self.image_patch_tokens *= config.num_image_with_embedding
  132. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  133. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  134. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  135. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  136. self.position_embedding_type = position_embedding_type or getattr(
  137. config, "position_embedding_type", "absolute"
  138. )
  139. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  140. self.max_position_embeddings = config.max_position_embeddings
  141. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  142. def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
  143. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  144. x = x.view(new_x_shape)
  145. return x.permute(0, 2, 1, 3)
  146. def forward(
  147. self,
  148. hidden_states: torch.Tensor,
  149. attention_mask: Optional[torch.FloatTensor] = None,
  150. head_mask: Optional[torch.FloatTensor] = None,
  151. past_key_value: Optional[Cache] = None,
  152. output_attentions: Optional[bool] = False,
  153. pixel_values_present: Optional[bool] = False,
  154. ) -> Tuple[torch.Tensor]:
  155. mixed_query_layer = self.query(hidden_states)
  156. cutoff = self.image_patch_tokens if pixel_values_present else 0
  157. key_layer = self.transpose_for_scores(self.key(hidden_states))
  158. value_layer = self.transpose_for_scores(self.value(hidden_states))
  159. if past_key_value is not None:
  160. # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component.
  161. key_layer_past, value_layer_past = past_key_value.update(
  162. key_layer[:, :, cutoff:, :], value_layer[:, :, cutoff:, :], self.layer_idx
  163. )
  164. key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2)
  165. value_layer = torch.cat([value_layer[:, :, :cutoff, :], value_layer_past], dim=2)
  166. query_layer = self.transpose_for_scores(mixed_query_layer)
  167. # Take the dot product between "query" and "key" to get the raw attention scores.
  168. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  169. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  170. query_length, key_length = query_layer.shape[2], key_layer.shape[2]
  171. if past_key_value is not None:
  172. position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
  173. -1, 1
  174. )
  175. else:
  176. position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  177. position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  178. distance = position_ids_l - position_ids_r
  179. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  180. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  181. if self.position_embedding_type == "relative_key":
  182. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  183. attention_scores = attention_scores + relative_position_scores
  184. elif self.position_embedding_type == "relative_key_query":
  185. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  186. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  187. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  188. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  189. if attention_mask is not None:
  190. # Apply the attention mask is (precomputed for all layers in GitModel forward() function)
  191. attention_scores = attention_scores + attention_mask
  192. # Normalize the attention scores to probabilities.
  193. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  194. # This is actually dropping out entire tokens to attend to, which might
  195. # seem a bit unusual, but is taken from the original Transformer paper.
  196. attention_probs = self.dropout(attention_probs)
  197. # Mask heads if we want to
  198. if head_mask is not None:
  199. attention_probs = attention_probs * head_mask
  200. context_layer = torch.matmul(attention_probs, value_layer)
  201. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  202. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  203. context_layer = context_layer.view(new_context_layer_shape)
  204. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  205. outputs = outputs + (past_key_value,)
  206. return outputs
  207. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  208. class GitSelfOutput(nn.Module):
  209. def __init__(self, config):
  210. super().__init__()
  211. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  212. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  213. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  214. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  215. hidden_states = self.dense(hidden_states)
  216. hidden_states = self.dropout(hidden_states)
  217. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  218. return hidden_states
  219. GIT_SELF_ATTENTION_CLASSES = {
  220. "eager": GitSelfAttention,
  221. }
  222. class GitAttention(nn.Module):
  223. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  224. super().__init__()
  225. self.self = GIT_SELF_ATTENTION_CLASSES[config._attn_implementation](
  226. config, position_embedding_type=position_embedding_type, layer_idx=layer_idx
  227. )
  228. self.output = GitSelfOutput(config)
  229. self.pruned_heads = set()
  230. # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
  231. def prune_heads(self, heads):
  232. if len(heads) == 0:
  233. return
  234. heads, index = find_pruneable_heads_and_indices(
  235. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  236. )
  237. # Prune linear layers
  238. self.self.query = prune_linear_layer(self.self.query, index)
  239. self.self.key = prune_linear_layer(self.self.key, index)
  240. self.self.value = prune_linear_layer(self.self.value, index)
  241. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  242. # Update hyper params and store pruned heads
  243. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  244. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  245. self.pruned_heads = self.pruned_heads.union(heads)
  246. def forward(
  247. self,
  248. hidden_states: torch.Tensor,
  249. attention_mask: Optional[torch.FloatTensor] = None,
  250. head_mask: Optional[torch.FloatTensor] = None,
  251. past_key_value: Optional[Cache] = None,
  252. output_attentions: Optional[bool] = False,
  253. pixel_values_present: Optional[bool] = False,
  254. ) -> Tuple[torch.Tensor]:
  255. self_outputs = self.self(
  256. hidden_states,
  257. attention_mask,
  258. head_mask,
  259. past_key_value,
  260. output_attentions,
  261. pixel_values_present,
  262. )
  263. attention_output = self.output(self_outputs[0], hidden_states)
  264. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  265. return outputs
  266. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  267. class GitIntermediate(nn.Module):
  268. def __init__(self, config):
  269. super().__init__()
  270. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  271. if isinstance(config.hidden_act, str):
  272. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  273. else:
  274. self.intermediate_act_fn = config.hidden_act
  275. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  276. hidden_states = self.dense(hidden_states)
  277. hidden_states = self.intermediate_act_fn(hidden_states)
  278. return hidden_states
  279. # Copied from transformers.models.bert.modeling_bert.BertOutput
  280. class GitOutput(nn.Module):
  281. def __init__(self, config):
  282. super().__init__()
  283. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  284. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  285. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  286. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  287. hidden_states = self.dense(hidden_states)
  288. hidden_states = self.dropout(hidden_states)
  289. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  290. return hidden_states
  291. class GitLayer(nn.Module):
  292. def __init__(self, config, layer_idx=None):
  293. super().__init__()
  294. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  295. self.seq_len_dim = 1
  296. self.attention = GitAttention(config, layer_idx=layer_idx)
  297. self.intermediate = GitIntermediate(config)
  298. self.output = GitOutput(config)
  299. def forward(
  300. self,
  301. hidden_states: torch.Tensor,
  302. attention_mask: Optional[torch.FloatTensor] = None,
  303. head_mask: Optional[torch.FloatTensor] = None,
  304. past_key_value: Optional[Cache] = None,
  305. output_attentions: Optional[bool] = False,
  306. pixel_values_present: Optional[bool] = False,
  307. ) -> Tuple[torch.Tensor]:
  308. # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  309. self_attention_outputs = self.attention(
  310. hidden_states,
  311. attention_mask,
  312. head_mask,
  313. output_attentions=output_attentions,
  314. past_key_value=past_key_value,
  315. pixel_values_present=pixel_values_present,
  316. )
  317. attention_output = self_attention_outputs[0]
  318. # if decoder, the last output is tuple of self-attn cache
  319. outputs = self_attention_outputs[1:-1]
  320. present_key_value = self_attention_outputs[-1]
  321. layer_output = apply_chunking_to_forward(
  322. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  323. )
  324. outputs = (layer_output,) + outputs
  325. # if decoder, return the attn key/values as the last output
  326. outputs = outputs + (present_key_value,)
  327. return outputs
  328. def feed_forward_chunk(self, attention_output):
  329. intermediate_output = self.intermediate(attention_output)
  330. layer_output = self.output(intermediate_output, attention_output)
  331. return layer_output
  332. class GitEncoder(nn.Module):
  333. def __init__(self, config):
  334. super().__init__()
  335. self.config = config
  336. self.layer = nn.ModuleList([GitLayer(config, i) for i in range(config.num_hidden_layers)])
  337. self.gradient_checkpointing = False
  338. def forward(
  339. self,
  340. hidden_states: torch.Tensor,
  341. attention_mask: Optional[torch.FloatTensor] = None,
  342. head_mask: Optional[torch.FloatTensor] = None,
  343. past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]] = None,
  344. use_cache: Optional[bool] = None,
  345. output_attentions: Optional[bool] = False,
  346. output_hidden_states: Optional[bool] = False,
  347. pixel_values_present: Optional[bool] = False,
  348. return_dict: Optional[bool] = True,
  349. ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
  350. if self.gradient_checkpointing and self.training:
  351. if use_cache:
  352. logger.warning_once(
  353. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  354. )
  355. use_cache = False
  356. # kept for BC (non `Cache` `past_key_values` inputs)
  357. return_legacy_cache = False
  358. if use_cache and not isinstance(past_key_values, Cache):
  359. return_legacy_cache = True
  360. if past_key_values is None:
  361. past_key_values = DynamicCache()
  362. else:
  363. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  364. logger.warning_once(
  365. "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
  366. "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
  367. "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
  368. )
  369. all_hidden_states = () if output_hidden_states else None
  370. all_self_attentions = () if output_attentions else None
  371. next_decoder_cache = None
  372. for i, layer_module in enumerate(self.layer):
  373. if output_hidden_states:
  374. all_hidden_states = all_hidden_states + (hidden_states,)
  375. layer_head_mask = head_mask[i] if head_mask is not None else None
  376. if self.gradient_checkpointing and self.training:
  377. layer_outputs = self._gradient_checkpointing_func(
  378. layer_module.__call__,
  379. hidden_states,
  380. attention_mask,
  381. layer_head_mask,
  382. past_key_values,
  383. output_attentions,
  384. )
  385. else:
  386. layer_outputs = layer_module(
  387. hidden_states,
  388. attention_mask,
  389. layer_head_mask,
  390. past_key_values,
  391. output_attentions,
  392. pixel_values_present,
  393. )
  394. hidden_states = layer_outputs[0]
  395. if use_cache:
  396. next_decoder_cache = layer_outputs[-1]
  397. if output_attentions:
  398. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  399. if output_hidden_states:
  400. all_hidden_states = all_hidden_states + (hidden_states,)
  401. next_cache = next_decoder_cache if use_cache else None
  402. if return_legacy_cache:
  403. next_cache = next_cache.to_legacy_cache()
  404. if not return_dict:
  405. return tuple(
  406. v
  407. for v in [
  408. hidden_states,
  409. next_cache,
  410. all_hidden_states,
  411. all_self_attentions,
  412. ]
  413. if v is not None
  414. )
  415. return BaseModelOutputWithPast(
  416. last_hidden_state=hidden_states,
  417. past_key_values=next_cache,
  418. hidden_states=all_hidden_states,
  419. attentions=all_self_attentions,
  420. )
  421. class GitPreTrainedModel(PreTrainedModel):
  422. """
  423. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  424. models.
  425. """
  426. config_class = GitConfig
  427. base_model_prefix = "git"
  428. supports_gradient_checkpointing = True
  429. _supports_cache_class = True
  430. _supports_quantized_cache = True
  431. def _init_weights(self, module):
  432. """Initialize the weights"""
  433. if isinstance(module, GitVisionEmbeddings):
  434. nn.init.normal_(module.class_embedding, mean=0.0, std=self.config.initializer_range)
  435. nn.init.normal_(module.patch_embedding.weight, std=self.config.initializer_range)
  436. nn.init.normal_(module.position_embedding.weight, std=self.config.initializer_range)
  437. if isinstance(module, nn.Linear):
  438. # Slightly different from the TF version which uses truncated_normal for initialization
  439. # cf https://github.com/pytorch/pytorch/pull/5617
  440. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  441. if module.bias is not None:
  442. module.bias.data.zero_()
  443. elif isinstance(module, nn.Embedding):
  444. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  445. if module.padding_idx is not None:
  446. module.weight.data[module.padding_idx].zero_()
  447. elif isinstance(module, nn.LayerNorm):
  448. module.bias.data.zero_()
  449. module.weight.data.fill_(1.0)
  450. GIT_START_DOCSTRING = r"""
  451. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  452. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  453. etc.)
  454. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  455. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  456. and behavior.
  457. Parameters:
  458. config ([`GitConfig`]): Model configuration class with all the parameters of the model.
  459. Initializing with a config file does not load the weights associated with the model, only the
  460. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  461. """
  462. GIT_INPUTS_DOCSTRING = r"""
  463. Args:
  464. input_ids (`torch.LongTensor` of shape `({0})`):
  465. Indices of input sequence tokens in the vocabulary.
  466. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  467. [`PreTrainedTokenizer.__call__`] for details.
  468. [What are input IDs?](../glossary#input-ids)
  469. attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
  470. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  471. - 1 for tokens that are **not masked**,
  472. - 0 for tokens that are **masked**.
  473. [What are attention masks?](../glossary#attention-mask)
  474. position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  475. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  476. config.max_position_embeddings - 1]`.
  477. [What are position IDs?](../glossary#position-ids)
  478. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  479. Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
  480. [`CLIPImageProcessor.__call__`] for details.
  481. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  482. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  483. - 1 indicates the head is **not masked**,
  484. - 0 indicates the head is **masked**.
  485. inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
  486. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  487. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  488. model's internal embedding lookup matrix.
  489. past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
  490. Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  491. blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  492. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  493. Two formats are allowed:
  494. - a [`~cache_utils.Cache`] instance, see our
  495. [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
  496. - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  497. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
  498. cache format.
  499. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
  500. legacy cache format will be returned.
  501. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
  502. have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
  503. of shape `(batch_size, sequence_length)`.
  504. output_attentions (`bool`, *optional*):
  505. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  506. tensors for more detail.
  507. output_hidden_states (`bool`, *optional*):
  508. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  509. more detail.
  510. interpolate_pos_encoding (`bool`, *optional*, defaults `False`):
  511. Whether to interpolate the pre-trained position encodings.
  512. return_dict (`bool`, *optional*):
  513. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  514. """
  515. # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Git
  516. class GitVisionEmbeddings(nn.Module):
  517. def __init__(self, config: GitVisionConfig):
  518. super().__init__()
  519. self.config = config
  520. self.embed_dim = config.hidden_size
  521. self.image_size = config.image_size
  522. self.patch_size = config.patch_size
  523. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  524. self.patch_embedding = nn.Conv2d(
  525. in_channels=config.num_channels,
  526. out_channels=self.embed_dim,
  527. kernel_size=self.patch_size,
  528. stride=self.patch_size,
  529. bias=False,
  530. )
  531. self.num_patches = (self.image_size // self.patch_size) ** 2
  532. self.num_positions = self.num_patches + 1
  533. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  534. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  535. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  536. """
  537. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  538. images. This method is also adapted to support torch.jit tracing.
  539. Adapted from:
  540. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  541. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  542. """
  543. num_patches = embeddings.shape[1] - 1
  544. position_embedding = self.position_embedding.weight.unsqueeze(0)
  545. num_positions = position_embedding.shape[1] - 1
  546. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  547. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  548. return self.position_embedding(self.position_ids)
  549. class_pos_embed = position_embedding[:, :1]
  550. patch_pos_embed = position_embedding[:, 1:]
  551. dim = embeddings.shape[-1]
  552. new_height = height // self.patch_size
  553. new_width = width // self.patch_size
  554. sqrt_num_positions = torch_int(num_positions**0.5)
  555. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  556. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  557. patch_pos_embed = nn.functional.interpolate(
  558. patch_pos_embed,
  559. size=(new_height, new_width),
  560. mode="bicubic",
  561. align_corners=False,
  562. )
  563. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  564. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  565. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
  566. batch_size, _, height, width = pixel_values.shape
  567. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  568. raise ValueError(
  569. f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
  570. )
  571. target_dtype = self.patch_embedding.weight.dtype
  572. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  573. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  574. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  575. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  576. if interpolate_pos_encoding:
  577. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  578. else:
  579. embeddings = embeddings + self.position_embedding(self.position_ids)
  580. return embeddings
  581. # Copied from transformers.models.clip.modeling_clip.CLIPMLP
  582. class GitVisionMLP(nn.Module):
  583. def __init__(self, config):
  584. super().__init__()
  585. self.config = config
  586. self.activation_fn = ACT2FN[config.hidden_act]
  587. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  588. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  589. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  590. hidden_states = self.fc1(hidden_states)
  591. hidden_states = self.activation_fn(hidden_states)
  592. hidden_states = self.fc2(hidden_states)
  593. return hidden_states
  594. # Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->GitVision
  595. class GitVisionAttention(nn.Module):
  596. """Multi-headed attention from 'Attention Is All You Need' paper"""
  597. def __init__(self, config):
  598. super().__init__()
  599. self.config = config
  600. self.embed_dim = config.hidden_size
  601. self.num_heads = config.num_attention_heads
  602. self.head_dim = self.embed_dim // self.num_heads
  603. if self.head_dim * self.num_heads != self.embed_dim:
  604. raise ValueError(
  605. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  606. f" {self.num_heads})."
  607. )
  608. self.scale = self.head_dim**-0.5
  609. self.dropout = config.attention_dropout
  610. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  611. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  612. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  613. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  614. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  615. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  616. def forward(
  617. self,
  618. hidden_states: torch.Tensor,
  619. attention_mask: Optional[torch.Tensor] = None,
  620. causal_attention_mask: Optional[torch.Tensor] = None,
  621. output_attentions: Optional[bool] = False,
  622. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  623. """Input shape: Batch x Time x Channel"""
  624. bsz, tgt_len, embed_dim = hidden_states.size()
  625. # get query proj
  626. query_states = self.q_proj(hidden_states) * self.scale
  627. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  628. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  629. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  630. query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
  631. key_states = key_states.view(*proj_shape)
  632. value_states = value_states.view(*proj_shape)
  633. src_len = key_states.size(1)
  634. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  635. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  636. raise ValueError(
  637. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  638. f" {attn_weights.size()}"
  639. )
  640. # apply the causal_attention_mask first
  641. if causal_attention_mask is not None:
  642. if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
  643. raise ValueError(
  644. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
  645. f" {causal_attention_mask.size()}"
  646. )
  647. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
  648. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  649. if attention_mask is not None:
  650. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  651. raise ValueError(
  652. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  653. )
  654. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  655. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  656. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  657. if output_attentions:
  658. # this operation is a bit akward, but it's required to
  659. # make sure that attn_weights keeps its gradient.
  660. # In order to do so, attn_weights have to reshaped
  661. # twice and have to be reused in the following
  662. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  663. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  664. else:
  665. attn_weights_reshaped = None
  666. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  667. attn_output = torch.bmm(attn_probs, value_states)
  668. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  669. raise ValueError(
  670. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  671. f" {attn_output.size()}"
  672. )
  673. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  674. attn_output = attn_output.transpose(1, 2)
  675. attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
  676. attn_output = self.out_proj(attn_output)
  677. return attn_output, attn_weights_reshaped
  678. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->GitVision
  679. class GitVisionEncoderLayer(nn.Module):
  680. def __init__(self, config: GitVisionConfig):
  681. super().__init__()
  682. self.embed_dim = config.hidden_size
  683. self.self_attn = GitVisionAttention(config)
  684. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  685. self.mlp = GitVisionMLP(config)
  686. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  687. def forward(
  688. self,
  689. hidden_states: torch.Tensor,
  690. attention_mask: torch.Tensor,
  691. causal_attention_mask: torch.Tensor,
  692. output_attentions: Optional[bool] = False,
  693. ) -> Tuple[torch.FloatTensor]:
  694. """
  695. Args:
  696. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  697. attention_mask (`torch.FloatTensor`): attention mask of size
  698. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  699. `(config.encoder_attention_heads,)`.
  700. output_attentions (`bool`, *optional*):
  701. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  702. returned tensors for more detail.
  703. """
  704. residual = hidden_states
  705. hidden_states = self.layer_norm1(hidden_states)
  706. hidden_states, attn_weights = self.self_attn(
  707. hidden_states=hidden_states,
  708. attention_mask=attention_mask,
  709. causal_attention_mask=causal_attention_mask,
  710. output_attentions=output_attentions,
  711. )
  712. hidden_states = residual + hidden_states
  713. residual = hidden_states
  714. hidden_states = self.layer_norm2(hidden_states)
  715. hidden_states = self.mlp(hidden_states)
  716. hidden_states = residual + hidden_states
  717. outputs = (hidden_states,)
  718. if output_attentions:
  719. outputs += (attn_weights,)
  720. return outputs
  721. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->GitVision, CLIPConfig
  722. class GitVisionEncoder(nn.Module):
  723. """
  724. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  725. [`GitVisionEncoderLayer`].
  726. Args:
  727. config: GitVisionConfig
  728. """
  729. def __init__(self, config: GitVisionConfig):
  730. super().__init__()
  731. self.config = config
  732. self.layers = nn.ModuleList([GitVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  733. self.gradient_checkpointing = False
  734. def forward(
  735. self,
  736. inputs_embeds,
  737. attention_mask: Optional[torch.Tensor] = None,
  738. causal_attention_mask: Optional[torch.Tensor] = None,
  739. output_attentions: Optional[bool] = None,
  740. output_hidden_states: Optional[bool] = None,
  741. return_dict: Optional[bool] = None,
  742. ) -> Union[Tuple, BaseModelOutput]:
  743. r"""
  744. Args:
  745. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  746. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  747. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  748. than the model's internal embedding lookup matrix.
  749. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  750. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  751. - 1 for tokens that are **not masked**,
  752. - 0 for tokens that are **masked**.
  753. [What are attention masks?](../glossary#attention-mask)
  754. causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  755. Causal mask for the text model. Mask values selected in `[0, 1]`:
  756. - 1 for tokens that are **not masked**,
  757. - 0 for tokens that are **masked**.
  758. [What are attention masks?](../glossary#attention-mask)
  759. output_attentions (`bool`, *optional*):
  760. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  761. returned tensors for more detail.
  762. output_hidden_states (`bool`, *optional*):
  763. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  764. for more detail.
  765. return_dict (`bool`, *optional*):
  766. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  767. """
  768. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  769. output_hidden_states = (
  770. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  771. )
  772. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  773. encoder_states = () if output_hidden_states else None
  774. all_attentions = () if output_attentions else None
  775. hidden_states = inputs_embeds
  776. for idx, encoder_layer in enumerate(self.layers):
  777. if output_hidden_states:
  778. encoder_states = encoder_states + (hidden_states,)
  779. if self.gradient_checkpointing and self.training:
  780. layer_outputs = self._gradient_checkpointing_func(
  781. encoder_layer.__call__,
  782. hidden_states,
  783. attention_mask,
  784. causal_attention_mask,
  785. output_attentions,
  786. )
  787. else:
  788. layer_outputs = encoder_layer(
  789. hidden_states,
  790. attention_mask,
  791. causal_attention_mask,
  792. output_attentions=output_attentions,
  793. )
  794. hidden_states = layer_outputs[0]
  795. if output_attentions:
  796. all_attentions = all_attentions + (layer_outputs[1],)
  797. if output_hidden_states:
  798. encoder_states = encoder_states + (hidden_states,)
  799. if not return_dict:
  800. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  801. return BaseModelOutput(
  802. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  803. )
  804. GIT_VISION_INPUTS_DOCSTRING = r"""
  805. Args:
  806. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  807. Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
  808. [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
  809. output_attentions (`bool`, *optional*):
  810. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  811. tensors for more detail.
  812. output_hidden_states (`bool`, *optional*):
  813. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  814. more detail.
  815. interpolate_pos_encoding (`bool`, *optional*, defaults `False`):
  816. Whether to interpolate the pre-trained position encodings.
  817. return_dict (`bool`, *optional*):
  818. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  819. """
  820. class GitVisionTransformer(nn.Module):
  821. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPEncoder->GitVisionEncoder, AltCLIP->Git
  822. def __init__(self, config: GitVisionConfig):
  823. super().__init__()
  824. self.config = config
  825. embed_dim = config.hidden_size
  826. self.embeddings = GitVisionEmbeddings(config)
  827. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  828. self.encoder = GitVisionEncoder(config)
  829. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  830. @add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
  831. @replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
  832. def forward(
  833. self,
  834. pixel_values: Optional[torch.FloatTensor] = None,
  835. output_attentions: Optional[bool] = None,
  836. output_hidden_states: Optional[bool] = None,
  837. interpolate_pos_encoding: Optional[bool] = False,
  838. return_dict: Optional[bool] = None,
  839. ) -> Union[Tuple, BaseModelOutput]:
  840. r"""
  841. Returns:
  842. """
  843. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  844. output_hidden_states = (
  845. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  846. )
  847. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  848. if pixel_values is None:
  849. raise ValueError("You have to specify pixel_values")
  850. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  851. hidden_states = self.pre_layrnorm(hidden_states)
  852. encoder_outputs = self.encoder(
  853. inputs_embeds=hidden_states,
  854. output_attentions=output_attentions,
  855. output_hidden_states=output_hidden_states,
  856. return_dict=return_dict,
  857. )
  858. last_hidden_state = encoder_outputs[0]
  859. last_hidden_state = self.post_layernorm(last_hidden_state)
  860. if not return_dict:
  861. return (last_hidden_state,) + encoder_outputs[1:]
  862. return BaseModelOutput(
  863. last_hidden_state=last_hidden_state,
  864. hidden_states=encoder_outputs.hidden_states,
  865. attentions=encoder_outputs.attentions,
  866. )
  867. @add_start_docstrings(
  868. """The vision model from CLIP, used in GIT, without any head or projection on top.""",
  869. GIT_START_DOCSTRING,
  870. )
  871. class GitVisionModel(GitPreTrainedModel):
  872. config_class = GitVisionConfig
  873. main_input_name = "pixel_values"
  874. # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP->Git
  875. def __init__(self, config: GitVisionConfig):
  876. super().__init__(config)
  877. self.vision_model = GitVisionTransformer(config)
  878. # Initialize weights and apply final processing
  879. self.post_init()
  880. def get_input_embeddings(self) -> nn.Module:
  881. return self.vision_model.embeddings.patch_embedding
  882. @add_start_docstrings_to_model_forward(GIT_VISION_INPUTS_DOCSTRING)
  883. @replace_return_docstrings(output_type=BaseModelOutput, config_class=GitVisionConfig)
  884. def forward(
  885. self,
  886. pixel_values: Optional[torch.FloatTensor] = None,
  887. output_attentions: Optional[bool] = None,
  888. output_hidden_states: Optional[bool] = None,
  889. interpolate_pos_encoding: bool = False,
  890. return_dict: Optional[bool] = None,
  891. ) -> Union[Tuple, BaseModelOutput]:
  892. r"""
  893. Returns:
  894. Examples:
  895. ```python
  896. >>> from PIL import Image
  897. >>> import requests
  898. >>> from transformers import AutoProcessor, GitVisionModel
  899. >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
  900. >>> model = GitVisionModel.from_pretrained("microsoft/git-base")
  901. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  902. >>> image = Image.open(requests.get(url, stream=True).raw)
  903. >>> inputs = processor(images=image, return_tensors="pt")
  904. >>> outputs = model(**inputs)
  905. >>> last_hidden_state = outputs.last_hidden_state
  906. ```"""
  907. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  908. return self.vision_model(
  909. pixel_values=pixel_values,
  910. output_attentions=output_attentions,
  911. output_hidden_states=output_hidden_states,
  912. interpolate_pos_encoding=interpolate_pos_encoding,
  913. return_dict=return_dict,
  914. )
  915. class GitProjection(nn.Module):
  916. def __init__(self, config: GitConfig):
  917. super().__init__()
  918. self.config = config
  919. self.visual_projection = nn.Sequential(
  920. nn.Linear(config.vision_config.hidden_size, config.hidden_size),
  921. nn.LayerNorm(config.hidden_size, eps=config.vision_config.layer_norm_eps),
  922. )
  923. def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
  924. return self.visual_projection(embeddings)
  925. @add_start_docstrings(
  926. "The bare GIT Model transformer consisting of a CLIP image encoder and text decoder outputting raw hidden-states"
  927. " without any specific head on top.",
  928. GIT_START_DOCSTRING,
  929. )
  930. class GitModel(GitPreTrainedModel):
  931. def __init__(self, config):
  932. super().__init__(config)
  933. self.config = config
  934. self.embeddings = GitEmbeddings(config)
  935. self.image_encoder = GitVisionModel(config.vision_config)
  936. self.encoder = GitEncoder(config)
  937. self.visual_projection = GitProjection(config)
  938. if config.num_image_with_embedding is not None:
  939. self.img_temperal_embedding = nn.ParameterList(
  940. nn.Parameter(torch.zeros(1, 1, config.vision_config.hidden_size))
  941. for _ in range(config.num_image_with_embedding)
  942. )
  943. # Initialize weights and apply final processing
  944. self.post_init()
  945. def get_input_embeddings(self):
  946. return self.embeddings.word_embeddings
  947. def set_input_embeddings(self, value):
  948. self.embeddings.word_embeddings = value
  949. def _prune_heads(self, heads_to_prune):
  950. """
  951. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  952. class PreTrainedModel
  953. """
  954. for layer, heads in heads_to_prune.items():
  955. self.encoder.layer[layer].attention.prune_heads(heads)
  956. def _generate_future_mask(self, size: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
  957. # Default mask is for forward direction. Flip for backward direction.
  958. mask = torch.triu(torch.ones(size, size, device=device, dtype=dtype), diagonal=1)
  959. mask = mask.masked_fill(mask == 1, float("-inf"))
  960. return mask

    def create_attention_mask(self, tgt, memory, tgt_mask, past_key_values_length, memory_key_padding_mask=None):
        num_tgt = tgt.shape[1]
        num_memory = memory.shape[1]
        device = tgt.device
        dtype = tgt.dtype
        top_left = torch.zeros((num_memory, num_memory), device=device, dtype=dtype)
        top_right = torch.full(
            (num_memory, num_tgt + past_key_values_length),
            float("-inf"),
            device=tgt.device,
            dtype=dtype,
        )
        bottom_left = torch.zeros(
            (num_tgt, num_memory),
            dtype=dtype,
            device=tgt_mask.device,
        )

        if past_key_values_length > 0:
            tgt_mask = torch.zeros(
                (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
                dtype=dtype,
                device=tgt_mask.device,
            )

        left = torch.cat((top_left, bottom_left), dim=0)
        right = torch.cat((top_right, tgt_mask.to(dtype)), dim=0)

        full_attention_mask = torch.cat((left, right), dim=1)[None, :]

        if memory_key_padding_mask is None:
            memory_key_padding_mask = torch.full((memory.shape[0], memory.shape[1]), fill_value=False, device=device)
        # if it is False, it means valid. That is, it is not a padding
        if memory_key_padding_mask.dtype != torch.bool:
            raise ValueError("Memory key padding mask must be a boolean tensor.")
        zero_negative_infinity = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
        zero_negative_infinity[memory_key_padding_mask] = float("-inf")
        full_attention_mask = full_attention_mask.expand(
            (memory_key_padding_mask.shape[0], num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)
        )
        full_attention_mask = full_attention_mask.clone()
        origin_left = full_attention_mask[:, :, :num_memory]
        update = zero_negative_infinity[:, None, :]
        full_attention_mask[:, :, :num_memory] = origin_left + update

        # add axis for multi-head
        full_attention_mask = full_attention_mask[:, None, :, :]

        return full_attention_mask
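
    # Layout of the combined mask built above (queries as rows, keys as columns), sketched
    # for past_key_values_length == 0:
    #
    #                       image tokens    text tokens
    #     image tokens   [       0        |    -inf      ]
    #     text tokens    [       0        |   tgt_mask   ]
    #
    # Image tokens attend only to each other, while text tokens attend to every image token
    # and, through the causal tgt_mask, only to themselves and earlier text tokens. Any memory
    # positions flagged in memory_key_padding_mask are additionally masked with -inf.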

    @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=BaseModelOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
        r"""
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoProcessor, AutoModel
        >>> import requests
        >>> from PIL import Image

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base")
        >>> model = AutoModel.from_pretrained("microsoft/git-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = "this is an image of two cats"

        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_shape[1]

        # past_key_values_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = (
                past_key_values[0][0].shape[2]
                if not isinstance(past_key_values, Cache)
                else past_key_values.get_seq_length()
            )

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        projected_visual_features = None
        if pixel_values is not None:
            if pixel_values.ndim == 4:
                # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
                visual_features = self.image_encoder(
                    pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
                ).last_hidden_state

            elif pixel_values.ndim == 5:
                # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
                visual_features = []
                for frame_idx in range(pixel_values.shape[1]):
                    visual_features_frame = self.image_encoder(
                        pixel_values[:, frame_idx, :, :], interpolate_pos_encoding=interpolate_pos_encoding
                    ).last_hidden_state
                    visual_features_frame += self.img_temperal_embedding[frame_idx]
                    visual_features.append(visual_features_frame)

                # finally, concatenate all features along sequence dimension
                visual_features = torch.cat(visual_features, dim=1)

            else:
                raise ValueError("pixel_values must be of rank 4 or 5")

            projected_visual_features = self.visual_projection(visual_features)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        if projected_visual_features is None:
            projected_visual_features = torch.zeros(
                (embedding_output.shape[0], 0, embedding_output.shape[2]),
                dtype=embedding_output.dtype,
                device=embedding_output.device,
            )

        # Repeat visual features to match embedding batch size.
        projected_visual_features = projected_visual_features.repeat(
            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
        )

        # concatenate patch token and text token embeddings
        hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
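
        # Note: hidden_states now has sequence length num_image_tokens + seq_length, with the
        # projected image patch tokens first and the text token embeddings after them; the
        # hidden states returned below keep this layout.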

        # By default, an additive causal mask is created
        # for masking the future (one direction).
        tgt_mask = self._generate_future_mask(seq_length, embedding_output.dtype, embedding_output.device)

        # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
        combined_attention_mask = self.create_attention_mask(
            tgt=embedding_output,
            memory=projected_visual_features,
            tgt_mask=tgt_mask,
            past_key_values_length=past_key_values_length,
        )

        if attention_mask is not None:
            # if the user provides an attention mask, we add it to the default one
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _prepare_4d_attention_mask(
                attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
            ).to(embedding_output.device)
            if past_key_values_length > 0:
                expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
            else:
                combined_attention_mask[:, :, -input_shape[1] :, -input_shape[1] :] += expanded_attn_mask

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=combined_attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            pixel_values_present=pixel_values is not None,
        )
        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutputWithPast(
            last_hidden_state=sequence_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


@add_start_docstrings(
    """GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING
)
class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["output.weight"]

    def __init__(self, config):
        super().__init__(config)

        self.git = GitModel(config)
        self.output = nn.Linear(config.hidden_size, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.output

    def set_output_embeddings(self, new_embeddings):
        self.output = new_embeddings

    @add_start_docstrings_to_model_forward(GIT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.Tensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).

        Returns:

        Examples:

        Image captioning example:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForCausalLM
        >>> import requests
        >>> from PIL import Image

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values

        >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
        >>> generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> print(generated_caption)
        two cats sleeping on a pink blanket next to remotes.
        ```

        Visual question answering (VQA) example:

        ```python
        >>> from transformers import AutoProcessor, AutoModelForCausalLM
        >>> from huggingface_hub import hf_hub_download
        >>> from PIL import Image
        >>> import torch

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-textvqa")
        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-textvqa")

        >>> file_path = hf_hub_download(repo_id="nielsr/textvqa-sample", filename="bus.png", repo_type="dataset")
        >>> image = Image.open(file_path).convert("RGB")

        >>> pixel_values = processor(images=image, return_tensors="pt").pixel_values

        >>> question = "what does the front of the bus say at the top?"

        >>> input_ids = processor(text=question, add_special_tokens=False).input_ids
        >>> input_ids = [processor.tokenizer.cls_token_id] + input_ids
        >>> input_ids = torch.tensor(input_ids).unsqueeze(0)

        >>> generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
        >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True))
        ['what does the front of the bus say at the top? special']
        ```

        Video captioning example:

        ```python
        >>> import av
        >>> import numpy as np

        >>> from PIL import Image
        >>> from huggingface_hub import hf_hub_download
        >>> from transformers import AutoProcessor, AutoModelForCausalLM

        >>> processor = AutoProcessor.from_pretrained("microsoft/git-base-vatex")
        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vatex")

        >>> # set seed for reproducibility
        >>> np.random.seed(45)

        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`List[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])

        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
        ...     '''
        ...     Sample a given number of frame indices from the video.
        ...     Args:
        ...         clip_len (`int`): Total number of frames to sample.
        ...         frame_sample_rate (`int`): Sample every n-th frame.
        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
        ...     Returns:
        ...         indices (`List[int]`): List of sampled frame indices
        ...     '''
        ...     converted_len = int(clip_len * frame_sample_rate)
        ...     end_idx = np.random.randint(converted_len, seg_len)
        ...     start_idx = end_idx - converted_len
        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        ...     return indices

        >>> # load video
        >>> file_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample frames
        >>> num_frames = model.config.num_image_with_embedding
        >>> indices = sample_frame_indices(
        ...     clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
        ... )
        >>> frames = read_video_pyav(container, indices)

        >>> pixel_values = processor(images=list(frames), return_tensors="pt").pixel_values

        >>> generated_ids = model.generate(pixel_values=pixel_values, max_length=50)

        >>> print("Generated caption:", processor.batch_decode(generated_ids, skip_special_tokens=True))
        Generated caption: ['a woman is sitting at a table and she is talking about the food she is holding.']
        ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.git(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            pixel_values=pixel_values,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        logits = self.output(sequence_output)

        loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            num_image_tokens = self.git.encoder.layer[0].attention.self.image_patch_tokens
            shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
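            # Sketch of the shift above: the logit at position num_image_tokens + t predicts
            # text token t + 1, so image positions and the final position are dropped from the
            # logits and the first token is dropped from the labels before computing the loss.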

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs
    ):
        # Overwritten -- `git` has special cache handling and doesn't support generating from `inputs_embeds` atm

        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values.get_seq_length()

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        input_shape = input_ids.shape
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": kwargs.get("pixel_values", None),
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
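
    # Note on cached generation (a sketch of the behavior above): once past_key_values holds
    # past_length positions, only the input IDs after that offset are forwarded, while
    # pixel_values are passed through unchanged from the generate() kwargs at every step.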

    def _reorder_cache(self, past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
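
    # Beam-search sketch: _reorder_cache re-indexes each cached key/value tensor along the
    # batch dimension so the cache follows the beams selected at the current step, e.g. a
    # beam_idx of torch.tensor([2, 0, 1]) reorders a cache built for three beams accordingly.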