# coding=utf-8
# Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch KOSMOS-2 model."""

import math
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPooling,
    CausalLMOutputWithCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
    torch_int,
)
from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = Kosmos2Config


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for uni-directional (decoder) self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
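

# Illustrative sketch (not part of the upstream module) of how the two helpers above are
# typically combined for a left-padded decoder batch: `_make_causal_mask` blocks attention to
# future positions, `_expand_mask` turns a `[bsz, seq_len]` padding mask into additive `-inf`
# biases, and the two are summed. The shapes and dtype below are assumptions.
#
#   >>> attention_mask = torch.tensor([[0, 1, 1, 1]])  # first position is padding
#   >>> causal = _make_causal_mask((1, 4), torch.float32, attention_mask.device)
#   >>> padding = _expand_mask(attention_mask, torch.float32, tgt_len=4)
#   >>> combined = causal + padding
#   >>> combined.shape
#   torch.Size([1, 1, 4, 4])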


# Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: torch.Tensor
        padding_idx: int

    Returns: torch.Tensor
    """
    # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
    mask = input_ids.ne(padding_idx).int()
    incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
    return incremental_indices.long() + padding_idx
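

# Quick sketch of the helper above (values are illustrative): with `padding_idx=1`, padded
# positions keep the padding index while real tokens are numbered from `padding_idx + 1`.
#
#   >>> input_ids = torch.tensor([[1, 5, 6, 7]])  # assumes token id 1 is the pad token
#   >>> create_position_ids_from_input_ids(input_ids, padding_idx=1)
#   tensor([[1, 2, 3, 4]])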


KOSMOS2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`Kosmos2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

KOSMOS2_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

KOSMOS2_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
        image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate the locations in a sequence at which to insert the image features. Mask values selected
            in `[0, 1]`:

            - 1 for places where the image features are inserted,
            - 0 for places that are not for image features (i.e. for text tokens).
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

KOSMOS2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`CLIPImageProcessor.__call__`] for details.
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to indicate the locations in a sequence at which to insert the image features. Mask values selected
            in `[0, 1]`:

            - 1 for places where the image features are inserted,
            - 0 for places that are not for image features (i.e. for text tokens).
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@dataclass
class Kosmos2ModelOutput(ModelOutput):
    """
    Base class for [`Kosmos2Model`]'s outputs that also contains the image embeddings and the vision model's output.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
        projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
            the weighted average in the self-attention heads.
        vision_model_output (`BaseModelOutputWithPooling`, *optional*):
            The output of the [`Kosmos2VisionModel`].
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
    """

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_embeds: Optional[torch.FloatTensor] = None
    projection_attentions: Optional[Tuple[torch.FloatTensor]] = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


@dataclass
class Kosmos2ForConditionalGenerationModelOutput(ModelOutput):
    """
    Model output class for `Kosmos2ForConditionalGeneration`.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
        projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
            the weighted average in the self-attention heads.
        vision_model_output (`BaseModelOutputWithPooling`, *optional*):
            The output of the [`Kosmos2VisionModel`].
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and optionally if
            `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,
            encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
            `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
            input) to speed up sequential decoding.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    image_embeds: Optional[torch.FloatTensor] = None
    projection_attentions: Optional[Tuple[torch.FloatTensor]] = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


# Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Kosmos2
class Kosmos2VisionEmbeddings(nn.Module):
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method interpolates the pre-trained position encodings so that the model can be used on
        higher-resolution images. It is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        position_embedding = self.position_embedding.weight.unsqueeze(0)
        num_positions = position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embedding(self.position_ids)

        class_pos_embed = position_embedding[:, :1]
        patch_pos_embed = position_embedding[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        batch_size, _, height, width = pixel_values.shape
        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
            )
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
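

# Patch-count sketch for the embeddings above. The 224/14 figures are assumptions that match a
# ViT-L/14-style vision config (e.g. the released kosmos-2-patch14-224 checkpoint):
#
#   >>> image_size, patch_size = 224, 14
#   >>> (image_size // patch_size) ** 2 + 1  # 16 x 16 patches + 1 class token
#   257
#   >>> (336 // patch_size) ** 2 + 1  # with interpolate_pos_encoding=True at 336 x 336
#   577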


# Copied from transformers.models.clip.modeling_clip.CLIPAttention with CLIP->Kosmos2Vision
class Kosmos2VisionAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # apply the causal_attention_mask first
        if causal_attention_mask is not None:
            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
                    f" {causal_attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped
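

# Shape sketch for the attention module above, using a deliberately tiny (assumed) config:
#
#   >>> cfg = Kosmos2VisionConfig(hidden_size=64, num_attention_heads=4)
#   >>> attn = Kosmos2VisionAttention(cfg)
#   >>> out, weights = attn(torch.randn(2, 10, 64), output_attentions=True)
#   >>> out.shape, weights.shape
#   (torch.Size([2, 10, 64]), torch.Size([2, 4, 10, 10]))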


# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Kosmos2Vision
class Kosmos2VisionMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision
class Kosmos2VisionEncoderLayer(nn.Module):
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Kosmos2VisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = Kosmos2VisionMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Kosmos2Vision
class Kosmos2VisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Kosmos2VisionEncoderLayer`].

    Args:
        config: Kosmos2VisionConfig
    """

    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated
                vectors than the model's internal embedding lookup matrix.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


# Similar to `transformers.models.clip.modeling_clip.CLIPVisionTransformer` but without docstring for `forward`
class Kosmos2VisionTransformer(nn.Module):
    # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPVision->Kosmos2Vision,ALTCLIP_VISION->KOSMOS2_VISION,AltCLIP->Kosmos2Vision
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = Kosmos2VisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = Kosmos2VisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


# Similar to `transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding` but allowing to pass `position_ids`
class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""

    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.__init__
    def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
        super().__init__()
        self.offset = 2
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)

    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.make_weights
    def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
        if hasattr(self, "weights"):
            # in forward put the weights on the correct dtype and device of the param
            emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)

        self.register_buffer("weights", emb_weights, persistent=False)

    @staticmethod
    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.get_embedding
    def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
        """
        Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
        "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0

        return emb.to(torch.get_default_dtype())

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        past_key_values_length: int = 0,
        position_ids: torch.Tensor = None,
    ):
        if input_ids is not None:
            bsz, seq_len = input_ids.size()
            if position_ids is None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(
                    input_ids, self.padding_idx, past_key_values_length
                ).to(input_ids.device)
        else:
            bsz, seq_len = inputs_embeds.size()[:-1]
            if position_ids is None:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)

        # expand embeddings if needed
        max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
        if max_pos > self.weights.size(0):
            self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)

        return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()

    # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.create_position_ids_from_inputs_embeds
    def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
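

# Sketch of the sinusoidal table built above (illustrative sizes): `get_embedding` returns one
# row per position, the `padding_idx` row is zeroed out, and `forward` simply gathers rows by
# position id.
#
#   >>> emb = Kosmos2TextSinusoidalPositionalEmbedding.get_embedding(
#   ...     num_embeddings=6, embedding_dim=8, padding_idx=1
#   ... )
#   >>> emb.shape
#   torch.Size([6, 8])
#   >>> bool(emb[1].abs().sum() == 0)  # the padding row stays all zeros
#   True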


class KosmosTextAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # Similar to transformers.models.bart.modeling_bart.BartAttention.__init__ except an additional `inner_attn_ln`.
    def __init__(
        self,
        config,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        add_inner_attn_layernorm: bool = False,
        bias: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # End copy

        self.inner_attn_ln = None
        if add_inner_attn_layernorm:
            self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def _shape(self, projection: torch.Tensor) -> torch.Tensor:
        new_projection_shape = projection.size()[:-1] + (self.num_heads, self.head_dim)
        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
        return new_projection

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if encoder_hidden_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = encoder_hidden_states is not None
        batch_size, seq_length = hidden_states.shape[:2]

        # use encoder_hidden_states if cross attention
        current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        # checking that the `sequence_length` of the `past_key_value` is the same as the provided
        # `encoder_hidden_states` to support prefix tuning
        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
            # reuse k, v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        else:
            key_states = self._shape(self.k_proj(current_states))
            value_states = self._shape(self.v_proj(current_states))
            if past_key_value is not None and not is_cross_attention:
                # reuse k, v, self_attention
                key_states = torch.cat([past_key_value[0], key_states], dim=2)
                value_states = torch.cat([past_key_value[1], value_states], dim=2)

        query_states = self._shape(self.q_proj(hidden_states) * self.scaling)
        attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2))

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        src_len = key_states.size(2)

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, seq_length, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, seq_length, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # attn_output = torch.bmm(attn_probs, value_states) ?
        context_states = torch.matmul(attn_weights, value_states)
        # attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) ?
        context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1)

        if self.inner_attn_ln is not None:
            context_states = self.inner_attn_ln(context_states)

        attn_output = self.out_proj(context_states)

        return attn_output, attn_weights, past_key_value
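

# Cache-growth sketch for the decoder self-attention above. The config values are assumptions
# chosen small for illustration; each call appends one key/value slice along dim=2.
#
#   >>> cfg = Kosmos2TextConfig(embed_dim=64, attention_heads=4)
#   >>> attn = KosmosTextAttention(cfg, embed_dim=64, num_heads=4, is_decoder=True)
#   >>> _, _, past = attn(torch.randn(1, 3, 64))  # 3-token prompt
#   >>> past[0].shape  # cached keys: (bsz, num_heads, seq_len, head_dim)
#   torch.Size([1, 4, 3, 16])
#   >>> _, _, past = attn(torch.randn(1, 1, 64), past_key_value=past)  # one new token
#   >>> past[0].shape
#   torch.Size([1, 4, 4, 16])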


class Kosmos2TextFFN(nn.Module):
    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()

        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.fc1 = nn.Linear(config.embed_dim, config.ffn_dim)
        self.fc2 = nn.Linear(config.ffn_dim, config.embed_dim)

        self.ffn_layernorm = nn.LayerNorm(config.ffn_dim, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.ffn_layernorm(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        return hidden_states


class Kosmos2TextBlock(nn.Module):
    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()
        self.embed_dim = config.embed_dim

        self.self_attn = KosmosTextAttention(
            config,
            embed_dim=self.embed_dim,
            num_heads=config.attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            add_inner_attn_layernorm=True,
        )
        self.dropout = config.dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        if config.add_cross_attention:
            self.encoder_attn = KosmosTextAttention(
                config,
                embed_dim=self.embed_dim,
                num_heads=config.attention_heads,
                dropout=config.attention_dropout,
                is_decoder=True,
                add_inner_attn_layernorm=False,
            )
            self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        self.ffn = Kosmos2TextFFN(config)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        # Self Attention
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        hidden_states = self.self_attn_layer_norm(hidden_states)

        # add present self-attn cache to positions 1,2 of present_key_value tuple
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        cross_attn_present_key_value = None
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            if not hasattr(self, "encoder_attn"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            residual = hidden_states

            hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None

            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=cross_attn_past_key_value,
                output_attentions=output_attentions,
            )
            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
            hidden_states = residual + hidden_states

            # add cross-attn to positions 3,4 of present_key_value tuple
            present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected
        residual = hidden_states

        hidden_states = self.final_layer_norm(hidden_states)

        # FFN
        hidden_states = self.ffn(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
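
# Output tuple layout of `Kosmos2TextBlock.forward` (as consumed by `Kosmos2TextTransformer.forward` below):
#   (hidden_states,)                               always
#   + (self_attn_weights, cross_attn_weights)      if `output_attentions`
#   + (present_key_value,)                         if `use_cache`
# This is why the per-layer cache is read at index `3 if output_attentions else 1` in the decoder loop.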


class Kosmos2TextTransformer(nn.Module):
    """
    Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2TextBlock`].

    Args:
        config: Kosmos2TextConfig
    """

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__()
        self.config = config
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop

        self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0
        self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id)

        self.embed_positions = Kosmos2TextSinusoidalPositionalEmbedding(
            num_positions=config.max_position_embeddings,
            embedding_dim=config.embed_dim,
            padding_idx=config.pad_token_id,
        )

        self.layers = nn.ModuleList([Kosmos2TextBlock(config) for _ in range(config.layers)])
        self.layer_norm = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)

        self.gradient_checkpointing = False

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask
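
    # Illustrative note: with `input_shape == (1, 3)`, no padding mask and no cache, `_make_causal_mask`
    # yields a (1, 1, 3, 3) additive mask that is 0 on and below the diagonal and a large negative value
    # above it, so token i can only attend to tokens <= i; a padding mask, when given, is expanded to the
    # same shape and added on top.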

    def forward_embedding(
        self,
        input_ids,
        inputs_embeds: torch.Tensor = None,
        image_embeds: torch.Tensor = None,
        img_input_mask: torch.Tensor = None,
        past_key_values_length: int = 0,
        position_ids: torch.Tensor = None,
    ):
        # The argument `inputs_embeds` should be the one without being multiplied by `self.embed_scale`.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if image_embeds is not None:
            inputs_embeds[img_input_mask.to(dtype=torch.bool)] = image_embeds.to(inputs_embeds.device).view(
                -1, image_embeds.size(-1)
            )

        inputs_embeds = inputs_embeds * self.embed_scale

        # embed positions
        positions = self.embed_positions(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
            position_ids=position_ids,
        )
        positions = positions.to(inputs_embeds.device)

        hidden_states = inputs_embeds + positions

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        return hidden_states
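
    # Descriptive note: `forward_embedding` scatters the projected image features into the token embedding
    # sequence at the positions flagged by `img_input_mask` (the processor's `image_embeds_position_mask`),
    # then scales the embeddings and adds the sinusoidal position embeddings.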

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.shape
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        # We don't need img info. when `past_key_values_length` > 0
        if past_key_values_length > 0:
            image_embeds = None
            image_embeds_position_mask = None

        hidden_states = self.forward_embedding(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            image_embeds=image_embeds,
            img_input_mask=image_embeds_position_mask,
            past_key_values_length=past_key_values_length,
            position_ids=position_ids,
        )

        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, input_shape, hidden_states, past_key_values_length
        )

        # expand encoder attention mask
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            # use `hidden_states.dtype` here: `inputs_embeds` may be `None` when only `input_ids` is given
            encoder_attention_mask = _expand_mask(encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1])

        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
        present_key_value_states = () if use_cache else None

        # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(self.layers)):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {attn_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    cross_attn_layer_head_mask=(
                        cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
                    ),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )
            hidden_states = layer_outputs[0]

            if use_cache:
                present_key_value_states += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # add final layer norm
        hidden_states = self.layer_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )


class Kosmos2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Kosmos2Config
    supports_gradient_checkpointing = True
    _no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(self, Kosmos2VisionModel):
            factor = self.config.initializer_factor
        elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)):
            factor = self.config.vision_config.initializer_factor

        if isinstance(self, (Kosmos2TextModel, Kosmos2TextForCausalLM)):
            std = self.config.init_std
        elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)):
            std = self.config.text_config.init_std

        if isinstance(module, Kosmos2VisionEmbeddings):
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
            nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
        elif isinstance(module, Kosmos2VisionAttention):
            in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            out_proj_std = (module.embed_dim**-0.5) * factor
            nn.init.normal_(module.q_proj.weight, std=in_proj_std)
            nn.init.normal_(module.k_proj.weight, std=in_proj_std)
            nn.init.normal_(module.v_proj.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
            if module.q_proj.bias is not None:
                module.q_proj.bias.data.zero_()
            if module.k_proj.bias is not None:
                module.k_proj.bias.data.zero_()
            if module.v_proj.bias is not None:
                module.v_proj.bias.data.zero_()
            if module.out_proj.bias is not None:
                module.out_proj.bias.data.zero_()
        elif isinstance(module, Kosmos2VisionMLP):
            in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc_std)
            nn.init.normal_(module.fc2.weight, std=in_proj_std)
            if module.fc1.bias is not None:
                module.fc1.bias.data.zero_()
            if module.fc2.bias is not None:
                module.fc2.bias.data.zero_()
        elif isinstance(module, Kosmos2VisionEncoderLayer):
            module.layer_norm1.bias.data.zero_()
            module.layer_norm1.weight.data.fill_(1.0)
            module.layer_norm2.bias.data.zero_()
            module.layer_norm2.weight.data.fill_(1.0)
        elif isinstance(module, Kosmos2VisionTransformer):
            module.pre_layrnorm.bias.data.zero_()
            module.pre_layrnorm.weight.data.fill_(1.0)
            module.post_layernorm.bias.data.zero_()
            module.post_layernorm.weight.data.fill_(1.0)
        elif isinstance(module, KosmosTextAttention):
            nn.init.normal_(module.q_proj.weight, std=std)
            nn.init.normal_(module.k_proj.weight, std=std)
            nn.init.normal_(module.v_proj.weight, std=std)
            nn.init.normal_(module.out_proj.weight, std=std)
            if module.q_proj.bias is not None:
                module.q_proj.bias.data.zero_()
            if module.k_proj.bias is not None:
                module.k_proj.bias.data.zero_()
            if module.v_proj.bias is not None:
                module.v_proj.bias.data.zero_()
            if module.out_proj.bias is not None:
                module.out_proj.bias.data.zero_()
        elif isinstance(module, Kosmos2TextFFN):
            nn.init.normal_(module.fc1.weight, std=std)
            nn.init.normal_(module.fc2.weight, std=std)
            if module.fc1.bias is not None:
                module.fc1.bias.data.zero_()
            if module.fc2.bias is not None:
                module.fc2.bias.data.zero_()
        elif isinstance(module, Kosmos2TextForCausalLM):
            nn.init.normal_(module.lm_head.weight, std=std)
            if module.lm_head.bias is not None:
                module.lm_head.bias.data.zero_()
        elif isinstance(module, Kosmos2ImageToTextProjection):
            nn.init.normal_(module.dense.weight, std=std)
            if module.dense.bias is not None:
                module.dense.bias.data.zero_()
        elif isinstance(module, Kosmos2TextTransformer):
            module.embed_tokens.weight.data.normal_(mean=0.0, std=std)
            if module.embed_tokens.padding_idx is not None:
                module.embed_tokens.weight.data[module.embed_tokens.padding_idx].zero_()


class Kosmos2VisionModel(Kosmos2PreTrainedModel):
    config_class = Kosmos2VisionConfig
    main_input_name = "pixel_values"

    # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model
    def __init__(self, config: Kosmos2VisionConfig):
        super().__init__(config)
        self.model = Kosmos2VisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.get_input_embeddings with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model
    def get_input_embeddings(self) -> nn.Module:
        return self.model.embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(KOSMOS2_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Kosmos2VisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:
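
        Example (a minimal sketch; the checkpoint name follows the other examples in this file, and the random
        `pixel_values` tensor is only a placeholder standing in for a processed image):

        ```python
        >>> import torch
        >>> from transformers import Kosmos2Model

        >>> model = Kosmos2Model.from_pretrained("microsoft/kosmos-2-patch14-224")
        >>> pixel_values = torch.randn(1, 3, 224, 224)  # normally produced by the Kosmos-2 processor
        >>> vision_outputs = model.vision_model(pixel_values=pixel_values)
        >>> patch_features = vision_outputs.last_hidden_state  # (batch, num_patches + 1, vision hidden size)
        ```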
        """
        return self.model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )


class Kosmos2TextModel(Kosmos2PreTrainedModel):
    config_class = Kosmos2TextConfig

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__(config)
        self.model = Kosmos2TextTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=Kosmos2TextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        Returns:
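
        Example (a minimal text-only sketch; the token ids below are arbitrary placeholders within the
        vocabulary, and the hidden size shown matches the released `microsoft/kosmos-2-patch14-224` checkpoint):

        ```python
        >>> import torch
        >>> from transformers import Kosmos2Model

        >>> text_model = Kosmos2Model.from_pretrained("microsoft/kosmos-2-patch14-224").text_model
        >>> input_ids = torch.tensor([[0, 9, 10, 11, 2]])  # placeholder token ids
        >>> outputs = text_model(input_ids=input_ids)
        >>> list(outputs.last_hidden_state.shape)
        [1, 5, 2048]
        ```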
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


@add_start_docstrings(
    """
    The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
    config_class = Kosmos2TextConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: Kosmos2TextConfig):
        super().__init__(config)

        self.model = Kosmos2TextTransformer(config)
        self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(KOSMOS2_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=Kosmos2TextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next-word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:
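
        Example (a minimal sketch; the token ids are arbitrary placeholders and the checkpoint name follows the
        other examples in this file):

        ```python
        >>> import torch
        >>> from transformers import Kosmos2ForConditionalGeneration

        >>> lm = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224").text_model
        >>> input_ids = torch.tensor([[0, 9, 10, 11, 2]])  # placeholder token ids
        >>> outputs = lm(input_ids=input_ids, labels=input_ids)
        >>> logits, loss = outputs.logits, outputs.loss  # logits: (batch, seq, vocab_size); loss: scalar
        ```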
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if use_cache:
                logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
            use_cache = False

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
            )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        image_embeds=None,
        image_embeds_position_mask=None,
        past_key_values=None,
        attention_mask=None,
        use_cache=None,
        **model_kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        position_ids = None

        # cut input_ids if past_key_values is used
        if past_key_values is not None:
            position_ids = create_position_ids_from_input_ids(
                input_ids,
                padding_idx=self.config.pad_token_id,
                past_key_values_length=0,
            )[:, -1:]

            input_ids = input_ids[:, -1:]
            # the image info. is already encoded into the past keys/values
            image_embeds = None
            image_embeds_position_mask = None
        elif image_embeds_position_mask is not None:
            # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation)
            batch_size, seq_len = input_ids.size()
            mask_len = image_embeds_position_mask.size()[-1]
            image_embeds_position_mask = torch.cat(
                (
                    image_embeds_position_mask,
                    torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device),
                ),
                dim=1,
            )

        return {
            "input_ids": input_ids,
            "image_embeds": image_embeds,
            "image_embeds_position_mask": image_embeds_position_mask,
            "past_key_values": past_key_values,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "use_cache": use_cache,
        }

    @staticmethod
    # Copied from transformers.models.umt5.modeling_umt5.UMT5ForConditionalGeneration._reorder_cache
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


class Kosmos2ImageToTextProjection(nn.Module):
    """The layer that transforms the image model's output to part of the text model's input (namely, image features)"""

    def __init__(self, config: Kosmos2Config):
        super().__init__()
        self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim)
        self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim))

        self.x_attn = KosmosTextAttention(
            config.text_config,
            config.text_config.embed_dim,
            config.text_config.attention_heads,
            dropout=config.text_config.attention_dropout,
            is_decoder=False,
            add_inner_attn_layernorm=False,
        )

    def forward(self, features):
        hidden_states = self.dense(features)

        # shape = [batch, latent_query_num, h_dim]
        latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1)
        key_value_states = torch.cat([hidden_states, latent_query], dim=1)

        hidden_states, attn_weights, _ = self.x_attn(
            hidden_states=latent_query,
            encoder_hidden_states=key_value_states,
            past_key_value=None,
            attention_mask=None,
            output_attentions=None,
        )

        return hidden_states, attn_weights
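
# Descriptive note: the projection above acts as a learned resampler. `latent_query_num` learned queries
# cross-attend over the projected image patch features (concatenated with the queries themselves as
# keys/values), producing a fixed-length sequence of image embeddings for the text model.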


@add_start_docstrings(
    """
    KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model.
    """,
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2Model(Kosmos2PreTrainedModel):
    config_class = Kosmos2Config
    main_input_name = "pixel_values"

    def __init__(self, config: Kosmos2Config):
        super().__init__(config)

        self.text_model = Kosmos2TextModel(config.text_config)
        self.vision_model = Kosmos2VisionModel(config.vision_config)
        self.image_to_text_projection = Kosmos2ImageToTextProjection(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.model.embed_tokens

    def set_input_embeddings(self, value):
        self.text_model.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Kosmos2ModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        image_embeds: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Kosmos2ModelOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Kosmos2Model

        >>> model = Kosmos2Model.from_pretrained("microsoft/kosmos-2-patch14-224")
        >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

        >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = (
        ...     "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863>"
        ...     "</object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911>"
        ...     "</object>"
        ... )

        >>> inputs = processor(text=text, images=image, return_tensors="pt", add_eos_token=True)

        >>> last_hidden_state = model(
        ...     pixel_values=inputs["pixel_values"],
        ...     input_ids=inputs["input_ids"],
        ...     attention_mask=inputs["attention_mask"],
        ...     image_embeds_position_mask=inputs["image_embeds_position_mask"],
        ... ).last_hidden_state
        >>> list(last_hidden_state.shape)
        [1, 91, 2048]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_model_output = None
        projection_attentions = None
        if image_embeds is None:
            if pixel_values is None:
                raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")

            vision_model_output = self.vision_model(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                interpolate_pos_encoding=interpolate_pos_encoding,
                return_dict=return_dict,
            )
            # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
            image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
            # normalized features
            image_embeds = nn.functional.normalize(image_embeds, dim=-1)
            image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)

        outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            outputs = outputs + (image_embeds, projection_attentions, vision_model_output)
            return tuple(output for output in outputs if output is not None)

        return Kosmos2ModelOutput(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_embeds=image_embeds,
            projection_attentions=projection_attentions,
            vision_model_output=vision_model_output,
        )


@add_start_docstrings(
    """
    KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a
    language model.
    """,
    KOSMOS2_START_DOCSTRING,
)
class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
    config_class = Kosmos2Config
    main_input_name = "pixel_values"
    _tied_weights_keys = ["text_model.lm_head.weight"]

    def __init__(self, config: Kosmos2Config):
        super().__init__(config)

        self.text_model = Kosmos2TextForCausalLM(config.text_config)
        self.vision_model = Kosmos2VisionModel(config.vision_config)
        self.image_to_text_projection = Kosmos2ImageToTextProjection(config)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.model.embed_tokens

    def set_input_embeddings(self, value):
        self.text_model.model.embed_tokens = value

    def get_output_embeddings(self) -> nn.Module:
        return self.text_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.text_model.set_output_embeddings(new_embeddings)

    @add_start_docstrings_to_model_forward(KOSMOS2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=Kosmos2ForConditionalGenerationModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        image_embeds: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Kosmos2ForConditionalGenerationModelOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next-word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see the `input_ids` docstring). Tokens with indices set to `-100` are
            ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Kosmos2ForConditionalGeneration

        >>> model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224")
        >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")

        >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> prompt = "<grounding> An image of"

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> generated_ids = model.generate(
        ...     pixel_values=inputs["pixel_values"],
        ...     input_ids=inputs["input_ids"],
        ...     attention_mask=inputs["attention_mask"],
        ...     image_embeds=None,
        ...     image_embeds_position_mask=inputs["image_embeds_position_mask"],
        ...     use_cache=True,
        ...     max_new_tokens=64,
        ... )
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)
        >>> processed_text
        '<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>.'

        >>> caption, entities = processor.post_process_generation(generated_text)
        >>> caption
        'An image of a snowman warming himself by a fire.'

        >>> entities
        [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_model_output = None
        projection_attentions = None
        if image_embeds is None:
            if pixel_values is None:
                raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")

            vision_model_output = self.vision_model(
                pixel_values=pixel_values,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
            image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
            # normalized features
            image_embeds = nn.functional.normalize(image_embeds, dim=-1)
            image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)

        lm_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            outputs = lm_outputs + (image_embeds, projection_attentions, vision_model_output)
            return tuple(output for output in outputs if output is not None)

        return Kosmos2ForConditionalGenerationModelOutput(
            loss=lm_outputs.loss,
            logits=lm_outputs.logits,
            past_key_values=lm_outputs.past_key_values,
            hidden_states=lm_outputs.hidden_states,
            attentions=lm_outputs.attentions,
            image_embeds=image_embeds,
            projection_attentions=projection_attentions,
            vision_model_output=vision_model_output,
        )

    def generate(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        image_embeds_position_mask: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # in order to allow `inputs` argument (as in `GenerationMixin`)
        inputs = kwargs.pop("inputs", None)
        if pixel_values is not None and inputs is not None:
            raise ValueError(
                f"`inputs`: {inputs} were passed alongside `pixel_values` which is not allowed. "
                f"Make sure to either pass `inputs` or pixel_values=..."
            )
        if pixel_values is None and inputs is not None:
            pixel_values = inputs

        if image_embeds is None:
            vision_model_output = self.vision_model(pixel_values)
            # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
            image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
            # normalized features
            image_embeds = nn.functional.normalize(image_embeds, dim=-1)
            image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)

        output = self.text_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            image_embeds=image_embeds,
            image_embeds_position_mask=image_embeds_position_mask,
            **kwargs,
        )

        return output
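
# Descriptive note: `Kosmos2ForConditionalGeneration.generate` runs the vision tower and the image-to-text
# projection once to obtain `image_embeds`, then delegates decoding to `self.text_model.generate`; on later
# decoding steps the image information lives only in the key/value cache (see
# `prepare_inputs_for_generation`, which drops the image inputs once `past_key_values` is set).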