# coding=utf-8
# Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch FLAVA model."""

import collections
import math
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
    torch_int,
)
from .configuration_flava import (
    FlavaConfig,
    FlavaImageCodebookConfig,
    FlavaImageConfig,
    FlavaMultimodalConfig,
    FlavaTextConfig,
)


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/flava-full"

# Codebook docstring
_CHECKPOINT_FOR_CODEBOOK_DOC = "facebook/flava-image-codebook"

_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC = "FlavaImageConfig"
_CONFIG_CLASS_FOR_TEXT_MODEL_DOC = "FlavaTextConfig"
_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC = "FlavaMultimodalConfig"

_EXPECTED_IMAGE_OUTPUT_SHAPE = [1, 197, 768]

LOGIT_SCALE_CLAMP_MIN = 0
LOGIT_SCALE_CLAMP_MAX = 4.6052

FlavaPossibleConfigs = Union[FlavaTextConfig, FlavaImageConfig, FlavaMultimodalConfig]
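
# Editor's note (hedged): these clamp bounds are applied, later in the full file, to the learnable
# log-temperature of the contrastive head before it is exponentiated, in the CLIP style. Since
# exp(0) == 1 and exp(4.6052) ~= 100, the effective logit scale stays roughly in [1, 100].
# A minimal sketch of the usual pattern (the `logit_scale` parameter below is illustrative):
#
#     import torch
#     logit_scale = torch.nn.Parameter(torch.tensor(3.0))  # learnable log-temperature
#     temperature = logit_scale.clamp(LOGIT_SCALE_CLAMP_MIN, LOGIT_SCALE_CLAMP_MAX).exp()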
@dataclass
class FlavaModelOutput(ModelOutput):
    """
    Output from FlavaModel containing embeddings and outputs from individual encoders.

    Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
    transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
    `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.

    Args:
        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
            The image embeddings which are basically the pooled output of [`FlavaImageModel`].
        image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
            The output of the [`FlavaImageModel`].
        text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
            The text embeddings which are basically the pooled output of [`FlavaTextModel`].
        text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
            The output of the [`FlavaTextModel`].
        multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
            The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
        multimodal_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_multimodal_encoder` is `None` or `False`):
            The output of the [`FlavaMultimodalModel`].
    """

    image_embeddings: Optional[torch.FloatTensor] = None
    image_output: Optional[BaseModelOutputWithPooling] = None
    text_embeddings: Optional[torch.FloatTensor] = None
    text_output: Optional[BaseModelOutputWithPooling] = None
    multimodal_embeddings: Optional[torch.FloatTensor] = None
    multimodal_output: Optional[BaseModelOutputWithPooling] = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_output", "image_output", "multimodal_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
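
# Editor's note (hedged sketch): the docstring above recommends projecting the pooled embeddings
# through the model's `image_projection` / `text_projection` layers before computing image-text
# similarity. Assuming `model` is a loaded `FlavaModel` and `pooled_image` / `pooled_text` are
# hypothetical names for the pooled ([CLS]) image and text embeddings, retrieval-style scores
# could be computed roughly as:
#
#     image_emb = torch.nn.functional.normalize(model.image_projection(pooled_image), dim=-1)
#     text_emb = torch.nn.functional.normalize(model.text_projection(pooled_text), dim=-1)
#     similarity = image_emb @ text_emb.t()   # (image_batch_size, text_batch_size)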
@dataclass
class FlavaLosses(ModelOutput):
    """Class representing pretraining losses from FLAVA model.

    Args:
        mim (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels` and `pixel_values` are present, `input_ids_masked` is absent and `mim_weight` > 0.):
            Masked Image Modeling loss as used in BeIT calculated only for unimodal image data.
        mlm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels` and `input_ids_masked` are present, `pixel_values` is absent and `mlm_weight` > 0.):
            Masked Language Modeling loss as used in BERT calculated only for unimodal text data.
        itm (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `itm_labels`, `input_ids_masked` and `pixel_values` are present and `itm_weight` > 0.):
            Image Text Matching (ITM) loss calculated for paired image-text data. Note that ITM loss is calculated on
            masked pairs in FLAVA.
        global_contrastive (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `input_ids` and `pixel_values` are present and `global_contrastive_weight` > 0.):
            Contrastive loss for image-text similarity similar to CLIP but calculated globally for paired image-text
            data. This is calculated on unmasked images and texts.
        mmm_image (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mim_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_image_weight` > 0.):
            Masked Multimodal Modeling loss's image component calculated on paired image-text data.
        mmm_text (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mlm_labels`, `pixel_values` and `input_ids_masked` are present and `mmm_text_weight` > 0.):
            Masked Multimodal Modeling loss's text component calculated on paired image-text data.
    """

    mim: Optional[torch.FloatTensor] = None
    mlm: Optional[torch.FloatTensor] = None
    itm: Optional[torch.FloatTensor] = None
    global_contrastive: Optional[torch.FloatTensor] = None
    mmm_image: Optional[torch.FloatTensor] = None
    mmm_text: Optional[torch.FloatTensor] = None

    def all_none(self) -> bool:
        all_none = True
        for v in self.values():
            if v is not None:
                all_none = False
                break
        return all_none
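
# Editor's note (hedged sketch): downstream, the pretraining head typically combines whichever
# loss components are present into a single scalar; `all_none` above is the guard for that. A
# minimal equivalent of the pattern (illustrative values, not the model's actual weighting):
#
#     losses = FlavaLosses(mlm=torch.tensor(2.1), itm=torch.tensor(0.7))
#     total = None if losses.all_none() else sum(v for v in losses.values() if v is not None)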
@dataclass
class FlavaForPreTrainingOutput(ModelOutput):
    """
    Output from FlavaForPreTraining containing embeddings, and outputs from individual encoders.

    Note that `image_embeddings` and `text_embeddings` returned are similar to pooled output returned from a
    transformer. If you want embeddings for contrastive loss or retrieval use a FLAVA model's `image_projection` and
    `text_projection` layers on `image_embeddings` and `text_embeddings` respectively.

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `return_loss` is True):
            Total loss calculated for this model.
        loss_info (`FlavaLosses`):
            Detailed info for FLAVA Pretraining losses. Check `FlavaLosses` class description for the information on
            the keys.
        image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
            The image embeddings which are basically the pooled output of [`FlavaImageModel`].
        image_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
            The output of the [`FlavaImageModel`].
        text_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` are present):
            The text embeddings which are basically the pooled output of [`FlavaTextModel`].
        text_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` are present):
            The output of the [`FlavaTextModel`].
        multimodal_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
            The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
        multimodal_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids` and `pixel_values` are present and `skip_unmasked_multimodal_encoder` is `None` or `False`):
            The output of the [`FlavaMultimodalModel`].
        image_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `pixel_values` are present):
            The image embeddings which are basically the pooled output of [`FlavaImageModel`]. Uses `bool_masked_pos`
            to create masked images.
        image_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `pixel_values` are present):
            The output of the [`FlavaImageModel`]. Uses `bool_masked_pos` to create masked images.
        text_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` are present):
            The text embeddings which are basically the pooled output of [`FlavaTextModel`].
        text_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` are present):
            The output of the [`FlavaTextModel`].
        multimodal_masked_embeddings (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
            The multimodal embeddings which are basically the pooled output of [`FlavaMultimodalModel`].
        multimodal_masked_output (`BaseModelOutputWithPooling`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
            The output of the [`FlavaMultimodalModel`].
        mim_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` are present and `input_ids_masked` are not):
            The logits for MIM unimodal loss. Uses `bool_masked_pos` to get masked patches. The flattened output is
            returned when `bool_masked_pos` has some of the patches masked.
        mlm_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `input_ids_masked` are present and `pixel_values` are not):
            The logits for MLM unimodal loss. The flattened output is returned when `input_ids_masked` has some of
            the tokens masked.
        itm_logits (`torch.FloatTensor` of shape `(batch_size, 2)`, *optional*, returned when `input_ids_masked` and `pixel_values` are present):
            The logits for ITM loss. Note that ITM loss is calculated on masked pairs in FLAVA.
        mmm_image_logits (`torch.FloatTensor` of shape `(batch_size, num_image_patches, image_vocab_size)` or of shape `(total_masked_patches, image_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
            The logits for MMM image multimodal loss. Uses `bool_masked_pos` to get masked patches. The flattened
            output is returned when `bool_masked_pos` has some of the patches masked.
        mmm_text_logits (`torch.FloatTensor` of shape `(batch_size, text_seq_length, text_vocab_size)` or of shape `(total_masked_seq_length, text_vocab_size)`, *optional*, returned when `pixel_values` and `input_ids_masked` are present):
            The logits for MMM text multimodal loss. The flattened output is returned when `input_ids_masked` has
            some of the tokens masked.
        contrastive_logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeddings` and `text_embeddings` but passed through FLAVA's
            `image_projection` and `text_projection` layers respectively. This represents the image-text similarity
            scores. This is calculated on unmasked images and texts.
        contrastive_logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeddings` and `image_embeddings` but passed through FLAVA's
            `text_projection` and `image_projection` layers respectively. This is calculated on unmasked images and
            texts.
    """

    loss: Optional[torch.FloatTensor] = None
    loss_info: FlavaLosses = None
    image_embeddings: Optional[torch.FloatTensor] = None
    image_output: Optional[BaseModelOutputWithPooling] = None
    text_embeddings: Optional[torch.FloatTensor] = None
    text_output: Optional[BaseModelOutputWithPooling] = None
    multimodal_embeddings: Optional[torch.FloatTensor] = None
    multimodal_output: Optional[BaseModelOutputWithPooling] = None
    image_masked_embeddings: Optional[torch.FloatTensor] = None
    image_masked_output: Optional[BaseModelOutputWithPooling] = None
    text_masked_embeddings: Optional[torch.FloatTensor] = None
    text_masked_output: Optional[BaseModelOutputWithPooling] = None
    multimodal_masked_embeddings: Optional[torch.FloatTensor] = None
    multimodal_masked_output: Optional[BaseModelOutputWithPooling] = None
    mim_logits: Optional[torch.FloatTensor] = None
    mlm_logits: Optional[torch.FloatTensor] = None
    itm_logits: Optional[torch.FloatTensor] = None
    contrastive_logits_per_image: Optional[torch.FloatTensor] = None
    contrastive_logits_per_text: Optional[torch.FloatTensor] = None
    mmm_image_logits: Optional[torch.FloatTensor] = None
    mmm_text_logits: Optional[torch.FloatTensor] = None

    def to_tuple(self) -> Tuple[Any]:
        transformer_outputs = [
            "text_output",
            "image_output",
            "multimodal_output",
            "text_masked_output",
            "image_masked_output",
            "multimodal_masked_output",
        ]
        return tuple(self[k] if k not in transformer_outputs else getattr(self, k).to_tuple() for k in self.keys())
# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class FlavaImageEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: FlavaImageConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        use_mask_token = use_mask_token or config.mask_token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        self.patch_embeddings = PatchEmbeddings(
            image_size=config.image_size,
            patch_size=config.patch_size,
            num_channels=config.num_channels,
            embed_dim=config.hidden_size,
        )
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows interpolating the pre-trained position encodings so that the model can be used on
        higher-resolution images. It is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """

        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
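
    # Editor's note (hedged worked example): with the default 224x224 input and 16x16 patches the
    # position table holds 1 class token + 196 patch positions (a 14x14 grid). For a hypothetical
    # 320x320 input, the method above reshapes the 196 patch positions to (1, 14, 14, dim),
    # bicubically resizes them to the new 20x20 grid (320 // 16 == 20), flattens back to
    # (1, 400, dim) and re-attaches the class position, giving 401 positions in total. Roughly,
    # with `pos` standing in for the (1, 197, dim) position table:
    #
    #     grid = pos[:, 1:].reshape(1, 14, 14, dim).permute(0, 3, 1, 2)
    #     grid = nn.functional.interpolate(grid, size=(20, 20), mode="bicubic", align_corners=False)
    #     new_pos = torch.cat([pos[:, :1], grid.permute(0, 2, 3, 1).reshape(1, -1, dim)], dim=1)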
    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        batch_size, seq_len, _ = embeddings.size()
        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # B X H X W = B X HW
            if bool_masked_pos.dim() == 3:
                bool_masked_pos = bool_masked_pos.view(bool_masked_pos.size(0), -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings
# Based on timm implementation, which can be found here:
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
class PatchEmbeddings(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        num_channels: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        x = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return x
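
# Editor's note (hedged shape example): with the defaults above (224x224 images, 16x16 patches,
# embed_dim 768), `num_patches` is (224 // 16) * (224 // 16) == 196, and the strided Conv2d plus
# flatten/transpose turns a (batch, 3, 224, 224) tensor into (batch, 196, 768). With the [CLS]
# token added later in FlavaImageEmbeddings, this matches the documented
# `_EXPECTED_IMAGE_OUTPUT_SHAPE` of [1, 197, 768].
#
#     emb = PatchEmbeddings()                   # defaults: image_size=224, patch_size=16, embed_dim=768
#     out = emb(torch.randn(1, 3, 224, 224))    # -> torch.Size([1, 196, 768])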
class FlavaTextEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ):
        input_shape = input_ids.size()
        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # Set token_type_ids to the registered buffer from the constructor (all zeros), which is the usual case
        # when they are auto-generated. The registered buffer helps users trace the model without passing
        # token_type_ids. Solves issue #5664.
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings
class FlavaSelfAttention(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
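
# Editor's note (hedged shape walk-through, assuming the common 768 hidden size with 12 heads):
# a (batch, seq, 768) input is split by `transpose_for_scores` into (batch, 12, seq, 64)
# query/key/value tensors; the score matrix is (batch, 12, seq, seq), scaled by 1/sqrt(64),
# softmaxed, applied to the values, and the heads are finally merged back to (batch, seq, 768).
# A minimal standalone check with illustrative sizes:
#
#     q = torch.randn(2, 12, 197, 64)
#     k, v = q.clone(), q.clone()
#     scores = q @ k.transpose(-1, -2) / 64 ** 0.5      # (2, 12, 197, 197)
#     ctx = torch.softmax(scores, dim=-1) @ v           # (2, 12, 197, 64)
#     merged = ctx.permute(0, 2, 1, 3).reshape(2, 197, 768)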
class FlavaSelfOutput(nn.Module):
    """
    The residual connection is defined in FlavaLayer (same as ViTLayer) instead of here (as is the case with other
    models), due to the layernorm applied before each block.
    """

    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class FlavaAttention(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        self.attention = FlavaSelfAttention(config)
        self.output = FlavaSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(
            hidden_states, attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions
        )

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
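
# Editor's note (hedged usage sketch): head pruning is usually driven from the model level via
# `_prune_heads` (defined on the FLAVA models below), which dispatches to `FlavaAttention.prune_heads`
# one layer at a time. For example, to drop heads 0 and 2 of the first layer of a hypothetical
# already-instantiated `model`:
#
#     model._prune_heads({0: [0, 2]})
#     model.encoder.layer[0].attention.attention.num_attention_heads   # reduced by 2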
class FlavaIntermediate(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    # Copied from transformers.models.vit.modeling_vit.ViTIntermediate.forward
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class FlavaOutput(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # Copied from transformers.models.vit.modeling_vit.ViTOutput.forward
    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


class FlavaLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: FlavaPossibleConfigs) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = FlavaAttention(config)
        self.intermediate = FlavaIntermediate(config)
        self.output = FlavaOutput(config)

        # TODO: Check fp32 layer norm possibility
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs
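
# Editor's note: FlavaLayer is a pre-norm ("pre-LN") transformer block, i.e. schematically
#
#     x = x + Attention(LayerNorm(x))
#     x = x + FeedForward(LayerNorm(x))
#
# which is why FlavaSelfOutput does not apply a LayerNorm or residual itself: the first residual
# addition happens in this class and the second one inside FlavaOutput.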
class FlavaEncoder(nn.Module):
    def __init__(self, config: FlavaConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([FlavaLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
        )


class FlavaPooler(nn.Module):
    def __init__(self, config: FlavaPossibleConfigs):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
FLAVA_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`{config}`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

FLAVA_INPUTS_DOCSTRING_COMMON = r"""
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

FLAVA_IMAGE_INPUTS_DOCSTRING_BASE = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`FlavaImageProcessor.__call__`] for details.
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        interpolate_pos_encoding (`bool`, *optional*):
            Whether to interpolate the pre-trained position encodings.
"""

FLAVA_IMAGE_INPUTS_DOCSTRING = FLAVA_IMAGE_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON

FLAVA_TEXT_INPUTS_DOCSTRING_BASE = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
            [What are token type IDs?](../glossary#token-type-ids)
"""

FLAVA_TEXT_INPUTS_DOCSTRING = FLAVA_TEXT_INPUTS_DOCSTRING_BASE + FLAVA_INPUTS_DOCSTRING_COMMON

FLAVA_MULTIMODAL_INPUTS_DOCSTRING = (
    r"""
    Args:
        hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
            The concatenated hidden states of unimodal encoders.
"""
    + FLAVA_INPUTS_DOCSTRING_COMMON
)

FLAVA_MODEL_INPUTS_DOCSTRING_BASE = r"""
    Args:
        skip_multimodal_encoder (*bool*, *optional*):
            Skip any calculations for multimodal encoder. Useful if multimodal encoding is not going to be used.
"""

FLAVA_MODEL_INPUTS_DOCSTRING = (
    FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
    + FLAVA_TEXT_INPUTS_DOCSTRING_BASE
    + FLAVA_INPUTS_DOCSTRING_COMMON
    + FLAVA_MODEL_INPUTS_DOCSTRING_BASE
)

FLAVA_PRETRAINING_INPUTS_DOCSTRING = (
    r"""
    Args:
        input_ids_masked (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary. These are the masked version of the original input,
            to be used with MLM. Indices can be obtained using [`AutoTokenizer`] along with
            [`DataCollatorForMaskedLanguageModeling`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
"""
    + FLAVA_TEXT_INPUTS_DOCSTRING_BASE
    + FLAVA_IMAGE_INPUTS_DOCSTRING_BASE
    + r"""
        image_attention_mask (`torch.FloatTensor` of shape `({1})`, *optional*):
            Mask to avoid performing attention on padding token indices specifically for images. Mask values selected
            in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        skip_unmasked_multimodal_encoder (*bool*, *optional*):
            Skip any calculations for multimodal encoder for unmasked inputs. FLAVA pretraining doesn't need unmasked
            multimodal embeddings or outputs as of now.
        mlm_labels (`torch.LongTensor` of shape `(batch_size, text_seq_len)`, *optional*):
            Labels for computing the left-to-right language and multimodal masked modeling loss (next word prediction).
            Indices should be in `[-100, 0, ..., text_config.vocab_size - 1]` (see `input_ids` docstring). Tokens with
            indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0,
            ..., text_config.vocab_size - 1]`.
        mim_labels (`torch.LongTensor` of shape `(batch_size, image_num_patches)`, *optional*):
            Labels for computing the image and multimodal masked modeling loss. Indices should be in `[-100, 0, ...,
            image_config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
            computed for the tokens with labels in `[0, ..., image_config.vocab_size - 1]`. If not passed, they are
            generated automatically using the image codebook assigned to the model. By default, it uses
            [`FlavaImageCodebook`]. See [`FlavaImageCodebook`] to understand how to generate mim_labels.
        itm_labels (`torch.LongTensor` of shape `(batch_size, 1)`, *optional*):
            Labels for computing the image-text matching loss. 0 means the pairs don't match and 1 means they match.
            The pairs with 0 will be skipped for calculation of MMM and global contrastive losses as well.
        return_loss (`bool`, *optional*, defaults to `None`):
            Whether to return calculated loss or not.
"""
    + FLAVA_INPUTS_DOCSTRING_COMMON
)

FLAVA_PRETRAINING_START_DOCSTRING_EXTRA = r"""
    Parameters:
        image_codebook ([`nn.Module`]): If passed, the image codebook will be set to this. Otherwise, it will be
            initialized using the image_codebook_config defined in the config.
"""
class FlavaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = FlavaConfig
    base_model_prefix = "flava"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


@add_start_docstrings(
    "The bare FLAVA Image Model transformer outputting raw hidden-states without any specific head on top.",
    FLAVA_START_DOCSTRING.format(config="FlavaImageConfig"),
)
class FlavaImageModel(FlavaPreTrainedModel):
    config_class = FlavaImageConfig
    # This override allows us to load FlavaImageModel from FlavaModel/FlavaForPreTraining checkpoints.
    base_model_prefix = "flava.image_model"
    main_input_name = "pixel_values"

    def __init__(self, config: FlavaImageConfig, add_pooling_layer: bool = True):
        super().__init__(config)

        self.config = config

        self.embeddings = FlavaImageEmbeddings(config)
        self.encoder = FlavaEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = FlavaPooler(config) if add_pooling_layer else None

        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.patch_embeddings

    def set_input_embeddings(self, value: nn.Module):
        self.embeddings.patch_embeddings = value

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
    @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_CLASS_FOR_IMAGE_MODEL_DOC,
        modality="vision",
        expected_output=_EXPECTED_IMAGE_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
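
# Editor's note (hedged usage sketch, mirroring the doc checkpoint `_CHECKPOINT_FOR_DOC`):
#
#     from transformers import AutoImageProcessor, FlavaImageModel
#     processor = AutoImageProcessor.from_pretrained("facebook/flava-full")
#     model = FlavaImageModel.from_pretrained("facebook/flava-full")
#     inputs = processor(images=image, return_tensors="pt")   # `image` is a PIL image (not defined here)
#     outputs = model(**inputs)
#     outputs.last_hidden_state.shape                          # expected [1, 197, 768], see _EXPECTED_IMAGE_OUTPUT_SHAPE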
@add_start_docstrings(
    "The bare FLAVA Text Model transformer outputting raw hidden-states without any specific head on top.",
    FLAVA_START_DOCSTRING.format(config="FlavaTextConfig"),
)
class FlavaTextModel(FlavaPreTrainedModel):
    config_class = FlavaTextConfig
    # This override allows us to load FlavaTextModel from FlavaModel/FlavaForPreTraining checkpoints.
    base_model_prefix = "flava.text_model"

    def __init__(self, config: FlavaTextConfig, add_pooling_layer: bool = True):
        super().__init__(config)
        self.config = config

        self.embeddings = FlavaTextEmbeddings(config)
        self.encoder = FlavaEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = FlavaPooler(config) if add_pooling_layer else None

        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value: nn.Module):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
    @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_CLASS_FOR_TEXT_MODEL_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=input_ids.device)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, input_shape, input_ids.device
        )

        embedding_output = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
  889. @add_start_docstrings(
  890. "The bare FLAVA Multimodal Model transformer outputting raw hidden-states without any specific head on top.",
  891. FLAVA_START_DOCSTRING.format(config="FlavaMultimodalConfig"),
  892. )
  893. class FlavaMultimodalModel(FlavaPreTrainedModel):
  894. config_class = FlavaMultimodalConfig
  895. # This override allows us to load FlavaMultimodalModel from FlavaModel/FlavaForPreTraining checkpoints.
  896. base_model_prefix = "flava.multimodal_model"
  897. main_input_name = "hidden_states"
  898. def __init__(self, config: FlavaMultimodalConfig, add_pooling_layer=True):
  899. super().__init__(config)
  900. self.config = config
  901. self.use_cls_token = self.config.use_cls_token
  902. if self.use_cls_token:
  903. self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  904. self.encoder = FlavaEncoder(config)
  905. self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  906. self.pooler = FlavaPooler(config) if add_pooling_layer else None
  907. self.post_init()
  908. def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
  909. """
  910. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  911. class PreTrainedModel
  912. """
  913. for layer, heads in heads_to_prune.items():
  914. self.encoder.layer[layer].attention.prune_heads(heads)
  915. @add_start_docstrings_to_model_forward(
  916. FLAVA_MULTIMODAL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
  917. )
  918. @add_code_sample_docstrings(
  919. checkpoint=_CHECKPOINT_FOR_DOC,
  920. output_type=BaseModelOutputWithPooling,
  921. config_class=_CONFIG_CLASS_FOR_MULTIMODAL_MODEL_DOC,
  922. )
  923. def forward(
  924. self,
  925. hidden_states: torch.Tensor,
  926. attention_mask: Optional[torch.Tensor] = None,
  927. head_mask: Optional[torch.Tensor] = None,
  928. output_attentions: Optional[bool] = None,
  929. output_hidden_states: Optional[bool] = None,
  930. return_dict: Optional[bool] = None,
  931. ) -> Union[tuple, BaseModelOutputWithPooling]:
  932. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  933. output_hidden_states = (
  934. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  935. )
  936. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  937. batch_size, seq_length, _ = hidden_states.size()
  938. if self.use_cls_token:
  939. cls_tokens = self.cls_token.expand(batch_size, -1, -1)
  940. hidden_states = torch.cat((cls_tokens, hidden_states), dim=1)
  941. seq_length += 1
  942. if attention_mask is None:
  943. attention_mask = torch.ones((batch_size, seq_length), device=hidden_states.device)
  944. # Prepare head mask if needed
  945. # 1.0 in head_mask indicate we keep the head
  946. # attention_probs has shape bsz x n_heads x N x N
  947. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  948. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  949. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  950. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
  951. attention_mask, (batch_size, seq_length), hidden_states.device
  952. )
  953. encoder_outputs = self.encoder(
  954. hidden_states,
  955. attention_mask=extended_attention_mask,
  956. head_mask=head_mask,
  957. output_attentions=output_attentions,
  958. output_hidden_states=output_hidden_states,
  959. return_dict=return_dict,
  960. )
  961. sequence_output = encoder_outputs[0]
  962. sequence_output = self.layernorm(sequence_output)
  963. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  964. if not return_dict:
  965. return (sequence_output, pooled_output) + encoder_outputs[1:]
  966. return BaseModelOutputWithPooling(
  967. last_hidden_state=sequence_output,
  968. pooler_output=pooled_output,
  969. hidden_states=encoder_outputs.hidden_states,
  970. attentions=encoder_outputs.attentions,
  971. )
  972. @add_start_docstrings(
  973. "The bare FLAVA Model transformer outputting raw hidden-states without any specific head on top.",
  974. FLAVA_START_DOCSTRING.format(config="FlavaConfig"),
  975. )
  976. class FlavaModel(FlavaPreTrainedModel):
  977. config_class = FlavaConfig
  978. def __init__(self, config: FlavaConfig):
  979. super().__init__(config)
  980. if not isinstance(config.text_config, FlavaTextConfig):
  981. raise TypeError(
  982. "config.text_config is expected to be of type FlavaTextConfig but is of type"
  983. f" {type(config.text_config)}."
  984. )
  985. if not isinstance(config.image_config, FlavaImageConfig):
  986. raise TypeError(
  987. "config.image_config is expected to be of type FlavaImageConfig but is of type"
  988. f" {type(config.image_config)}."
  989. )
  990. if not isinstance(config.multimodal_config, FlavaMultimodalConfig):
  991. raise TypeError(
  992. "config.multimodal_config is expected to be of type FlavaMultimodalConfig but "
  993. + f"is of type {type(config.multimodal_config)}."
  994. )
  995. text_config = config.text_config
  996. image_config = config.image_config
  997. multimodal_config = config.multimodal_config
  998. self.projection_dim = config.projection_dim
  999. self.text_hidden_size = text_config.hidden_size
  1000. self.image_hidden_size = image_config.hidden_size
  1001. self.mm_hidden_size = multimodal_config.hidden_size
  1002. self.text_model = FlavaTextModel(text_config)
  1003. self.image_model = FlavaImageModel(image_config)
  1004. self.multimodal_model = FlavaMultimodalModel(multimodal_config)
  1005. self.image_projection = nn.Linear(self.image_hidden_size, self.projection_dim)
  1006. self.text_projection = nn.Linear(self.text_hidden_size, self.projection_dim)
  1007. self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
  1008. self.image_to_mm_projection = nn.Linear(self.image_hidden_size, self.mm_hidden_size)
  1009. self.text_to_mm_projection = nn.Linear(self.text_hidden_size, self.mm_hidden_size)
  1010. # Initialize weights and apply final processing
  1011. self.post_init()
  1012. @add_start_docstrings_to_model_forward(FLAVA_TEXT_INPUTS_DOCSTRING.format("batch_size, text_seq_length"))
  1013. def get_text_features(
  1014. self,
  1015. input_ids: Optional[torch.Tensor] = None,
  1016. attention_mask: Optional[torch.Tensor] = None,
  1017. token_type_ids: Optional[torch.Tensor] = None,
  1018. position_ids: Optional[torch.Tensor] = None,
  1019. output_attentions: Optional[bool] = None,
  1020. output_hidden_states: Optional[bool] = None,
  1021. return_dict: Optional[bool] = None,
  1022. ) -> torch.FloatTensor:
  1023. r"""
  1024. Returns:
  1025. text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
  1026. applying the projection layer to the pooled output of [`FlavaTextModel`].
  1027. Examples:
  1028. ```python
  1029. >>> from transformers import AutoProcessor, FlavaModel
  1030. >>> model = FlavaModel.from_pretrained("{0}")
  1031. >>> processor = AutoProcessor.from_pretrained("{0}")
  1032. >>> inputs = processor(
  1033. ... text=["a photo of a cat", "a photo of a dog"], max_length=77, padding="max_length", return_tensors="pt"
  1034. ... )
  1035. >>> text_features = model.get_text_features(**inputs)
  1036. ```""".format(_CHECKPOINT_FOR_DOC)
  1037. text_outputs = self.text_model(
  1038. input_ids=input_ids,
  1039. attention_mask=attention_mask,
  1040. token_type_ids=token_type_ids,
  1041. position_ids=position_ids,
  1042. output_attentions=output_attentions,
  1043. output_hidden_states=output_hidden_states,
  1044. return_dict=return_dict,
  1045. )
  1046. pooled_output = text_outputs[0] # last_hidden_state
  1047. text_features = self.text_projection(pooled_output)
  1048. return text_features
  1049. @add_start_docstrings_to_model_forward(FLAVA_IMAGE_INPUTS_DOCSTRING.format("batch_size, image_num_patches"))
  1050. def get_image_features(
  1051. self,
  1052. pixel_values: Optional[torch.Tensor] = None,
  1053. bool_masked_pos: Optional[torch.BoolTensor] = None,
  1054. interpolate_pos_encoding: Optional[bool] = None,
  1055. attention_mask: Optional[torch.Tensor] = None,
  1056. head_mask: Optional[torch.Tensor] = None,
  1057. output_attentions: Optional[bool] = None,
  1058. output_hidden_states: Optional[bool] = None,
  1059. return_dict: Optional[bool] = None,
  1060. ) -> torch.FloatTensor:
  1061. r"""
  1062. Returns:
  1063. image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
  1064. applying the projection layer to the pooled output of [`FlavaImageModel`].
  1065. Examples:
  1066. ```python
  1067. >>> from PIL import Image
  1068. >>> import requests
  1069. >>> from transformers import AutoProcessor, FlavaModel
  1070. >>> model = FlavaModel.from_pretrained("{0}")
  1071. >>> processor = AutoProcessor.from_pretrained("{0}")
  1072. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1073. >>> image = Image.open(requests.get(url, stream=True).raw)
  1074. >>> inputs = processor(images=image, return_tensors="pt")
  1075. >>> image_features = model.get_image_features(**inputs)
  1076. ```""".format(_CHECKPOINT_FOR_DOC)
  1077. image_outputs = self.image_model(
  1078. pixel_values=pixel_values,
  1079. bool_masked_pos=bool_masked_pos,
  1080. attention_mask=attention_mask,
  1081. head_mask=head_mask,
  1082. output_attentions=output_attentions,
  1083. output_hidden_states=output_hidden_states,
  1084. interpolate_pos_encoding=interpolate_pos_encoding,
  1085. return_dict=return_dict,
  1086. )
  1087. pooled_output = image_outputs[0] # last_hidden_state
  1088. image_features = self.image_projection(pooled_output)
  1089. return image_features
  1090. @add_start_docstrings_to_model_forward(
  1091. FLAVA_MODEL_INPUTS_DOCSTRING.format("batch_size, image_num_patches + text_seq_len")
  1092. )
  1093. @replace_return_docstrings(output_type=FlavaModelOutput, config_class=FlavaConfig)
  1094. def forward(
  1095. self,
  1096. input_ids: Optional[torch.LongTensor] = None,
  1097. pixel_values: Optional[torch.FloatTensor] = None,
  1098. attention_mask: Optional[torch.Tensor] = None,
  1099. token_type_ids: Optional[torch.Tensor] = None,
  1100. bool_masked_pos: Optional[torch.Tensor] = None,
  1101. position_ids: Optional[torch.LongTensor] = None,
  1102. image_attention_mask: Optional[torch.Tensor] = None,
  1103. skip_multimodal_encoder: Optional[bool] = None,
  1104. output_attentions: Optional[bool] = None,
  1105. output_hidden_states: bool = True,
  1106. return_dict: Optional[bool] = None,
  1107. ) -> Union[Tuple, FlavaOutput]:
  1108. r"""
  1109. Returns:
  1110. Examples:
  1111. ```python
  1112. >>> from PIL import Image
  1113. >>> import requests
  1114. >>> from transformers import AutoProcessor, FlavaModel
  1115. >>> model = FlavaModel.from_pretrained("facebook/flava-full")
  1116. >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")
  1117. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1118. >>> image = Image.open(requests.get(url, stream=True).raw)
  1119. >>> inputs = processor(text=["a photo of a cat"], images=image, return_tensors="pt", padding=True)
  1120. >>> outputs = model(**inputs)
  1121. >>> image_embeddings = outputs.image_embeddings
  1122. >>> text_embeddings = outputs.text_embeddings
  1123. >>> multimodal_embeddings = outputs.multimodal_embeddings
  1124. >>> outputs.image_embeddings.shape
  1125. torch.Size([1, 197, 768])
  1126. >>> text_embeddings.shape
  1127. torch.Size([1, 7, 768])
  1128. >>> multimodal_embeddings.shape
  1129. torch.Size([1, 205, 768])
  1130. ```
  1131. """
  1132. return_dict = return_dict if return_dict is not None else self.config.return_dict
  1133. if not output_hidden_states:
  1134. raise ValueError("FLAVA model requires hidden states to work. Please set `output_hidden_states=True`")
  1135. image_embeddings = None
  1136. image_states = None
  1137. image_mm_projection = None
  1138. image_output = None
  1139. if pixel_values is not None:
  1140. image_output = self.image_model(
  1141. pixel_values=pixel_values,
  1142. bool_masked_pos=bool_masked_pos,
  1143. attention_mask=image_attention_mask,
  1144. output_attentions=output_attentions,
  1145. output_hidden_states=output_hidden_states,
  1146. return_dict=return_dict,
  1147. )
  1148. image_embeddings, image_states = image_output[0], image_output[2]
  1149. # Note that these states don't use final layernorm in the transformer model
  1150. image_mm_projection = self.image_to_mm_projection(image_states[-1])
  1151. text_embeddings = None
  1152. text_states = None
  1153. text_mm_projection = None
  1154. text_output = None
  1155. if input_ids is not None:
  1156. text_output = self.text_model(
  1157. input_ids=input_ids,
  1158. attention_mask=attention_mask,
  1159. position_ids=position_ids,
  1160. token_type_ids=token_type_ids,
  1161. output_attentions=output_attentions,
  1162. output_hidden_states=output_hidden_states,
  1163. return_dict=return_dict,
  1164. )
  1165. text_embeddings, text_states = text_output[0], text_output[2]
  1166. # Note that these states don't use final layernorm in the transformer model
  1167. text_mm_projection = self.text_to_mm_projection(text_states[-1])
  1168. multimodal_embeddings = None
  1169. multimodal_output = None
  1170. if image_mm_projection is not None and text_mm_projection is not None and not skip_multimodal_encoder:
  1171. if attention_mask is not None:
  1172. batch_size, seq_len, _ = image_mm_projection.shape
  1173. if self.multimodal_model.use_cls_token:
  1174. seq_len += 1
  1175. attention_mask_image = torch.ones(batch_size, seq_len, device=image_mm_projection.device)
  1176. attention_multimodal = torch.cat([attention_mask_image, attention_mask], dim=1)
  1177. else:
  1178. attention_multimodal = None
  1179. multimodal_input = torch.cat([image_mm_projection, text_mm_projection], dim=1)
  1180. multimodal_output = self.multimodal_model(
  1181. multimodal_input, attention_mask=attention_multimodal, return_dict=return_dict
  1182. )
  1183. multimodal_embeddings = multimodal_output[0]
  1184. if not return_dict:
  1185. return (
  1186. image_embeddings,
  1187. image_output,
  1188. text_embeddings,
  1189. text_output,
  1190. multimodal_embeddings,
  1191. multimodal_output,
  1192. )
  1193. return FlavaModelOutput(
  1194. image_embeddings=image_embeddings,
  1195. image_output=image_output,
  1196. text_embeddings=text_embeddings,
  1197. text_output=text_output,
  1198. multimodal_embeddings=multimodal_embeddings,
  1199. multimodal_output=multimodal_output,
  1200. )
  1201. class FlavaImageCodebookResPath(nn.Module):
  1202. def __init__(self, in_size: int, out_size: int, **kwargs):
  1203. super().__init__()
  1204. hid_size = out_size // 4
  1205. path = OrderedDict()
  1206. path["relu_1"] = nn.ReLU()
  1207. path["conv_1"] = nn.Conv2d(in_size, hid_size, kernel_size=3, padding=1)
  1208. path["relu_2"] = nn.ReLU()
  1209. path["conv_2"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
  1210. path["relu_3"] = nn.ReLU()
  1211. path["conv_3"] = nn.Conv2d(hid_size, hid_size, kernel_size=3, padding=1)
  1212. path["relu_4"] = nn.ReLU()
  1213. path["conv_4"] = nn.Conv2d(hid_size, out_size, kernel_size=1, padding=0)
  1214. self.path = nn.Sequential(path)
  1215. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1216. return self.path(x)
  1217. class FlavaImageCodebookBlock(nn.Module):
  1218. def __init__(self, in_size: int, out_size: int, num_layers: int, **kwargs):
  1219. super().__init__()
  1220. self.post_gain = 1 / (num_layers**2)
  1221. if in_size != out_size:
  1222. self.id_path = nn.Conv2d(in_size, out_size, kernel_size=1, padding=0)
  1223. else:
  1224. self.id_path = nn.Identity()
  1225. self.res_path = FlavaImageCodebookResPath(in_size, out_size)
  1226. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1227. return self.id_path(x) + self.post_gain * self.res_path(x)
  1228. class FlavaImageCodebookLayerGroup(nn.Module):
  1229. def __init__(self, num_blocks: int, num_layers: int, in_size: int, out_size: int, use_pool: bool = True):
  1230. super().__init__()
  1231. blocks = OrderedDict()
  1232. for i in range(num_blocks):
  1233. if i == 0:
  1234. blocks[f"block_{i+1}"] = FlavaImageCodebookBlock(in_size, out_size, num_layers)
  1235. else:
  1236. blocks[f"block_{i+1}"] = FlavaImageCodebookBlock(out_size, out_size, num_layers)
  1237. if use_pool:
  1238. blocks["pool"] = nn.MaxPool2d(kernel_size=2)
  1239. self.group = nn.Sequential(blocks)
  1240. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1241. return self.group(x)
  1242. # Inspired by DALLE Encoder in https://github.com/openai/DALL-E/blob/5be4b236bc3ade6943662354117a0e83752cc322/dall_e/encoder.py#L42
  1243. @add_start_docstrings(
  1244. """
  1245. The FLAVA's image codebook model inspired from DALL-E's original encoder. Outputs raw hidden states and can be used
  1246. to generate image tokens for an image based on DALL-E's vocab. Used to generate labels for MIM. Use
  1247. `get_codebook_indices` to get image tokens for an image.
  1248. """,
  1249. FLAVA_START_DOCSTRING.format(config="FlavaImageCodebookConfig"),
  1250. )
  1251. class FlavaImageCodebook(FlavaPreTrainedModel):
  1252. base_model_prefix = ""
  1253. config_class = FlavaImageCodebookConfig
  1254. main_input_name = "pixel_values"
  1255. supports_gradient_checkpointing = False
  1256. def __init__(
  1257. self,
  1258. config: FlavaImageCodebookConfig,
  1259. **kwargs: Any,
  1260. ):
  1261. super().__init__(config)
  1262. self.config = config
  1263. self.num_groups = config.num_groups
  1264. self.input_channels = config.input_channels
  1265. self.num_blocks_per_group = config.num_blocks_per_group
  1266. self.hidden_size = config.hidden_size
  1267. self.vocab_size = config.vocab_size
  1268. num_layers = self.num_groups * self.num_blocks_per_group
  1269. output_blocks = OrderedDict()
  1270. output_blocks["relu"] = nn.ReLU()
  1271. output_blocks["conv"] = nn.Conv2d(8 * self.hidden_size, self.vocab_size, kernel_size=1, padding=0)
  1272. blocks = OrderedDict()
  1273. blocks["input"] = nn.Conv2d(self.input_channels, 1 * self.hidden_size, kernel_size=7, padding=3)
  1274. blocks["group_1"] = FlavaImageCodebookLayerGroup(
  1275. self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 1 * self.hidden_size
  1276. )
  1277. blocks["group_2"] = FlavaImageCodebookLayerGroup(
  1278. self.num_blocks_per_group, num_layers, 1 * self.hidden_size, 2 * self.hidden_size
  1279. )
  1280. blocks["group_3"] = FlavaImageCodebookLayerGroup(
  1281. self.num_blocks_per_group, num_layers, 2 * self.hidden_size, 4 * self.hidden_size
  1282. )
  1283. blocks["group_4"] = FlavaImageCodebookLayerGroup(
  1284. self.num_blocks_per_group, num_layers, 4 * self.hidden_size, 8 * self.hidden_size, use_pool=False
  1285. )
  1286. blocks["output"] = nn.Sequential(output_blocks)
  1287. self.blocks = nn.Sequential(blocks)
  1288. self.post_init()
  1289. if self.config.freeze:
  1290. for param in self.parameters():
  1291. param.requires_grad = False
  1292. def get_codebook_indices(self, pixel_values: torch.Tensor) -> torch.Tensor:
  1293. """
  1294. Args:
  1295. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  1296. Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
  1297. `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.
  1298. Examples:
  1299. ```python
  1300. >>> from PIL import Image
  1301. >>> import requests
  1302. >>> from transformers import AutoImageProcessor, FlavaImageCodebook
  1303. >>> model = FlavaImageCodebook.from_pretrained("{0}")
  1304. >>> image_processor = AutoImageProcessor.from_pretrained("{0}")
  1305. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1306. >>> image = Image.open(requests.get(url, stream=True).raw)
  1307. >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
  1308. >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
  1309. >>> outputs = model.get_codebook_indices(**inputs)
  1310. ```
  1311. """.format(_CHECKPOINT_FOR_CODEBOOK_DOC)
  1312. z_logits = self.blocks(pixel_values)
  1313. return torch.argmax(z_logits, axis=1)
  1314. def get_codebook_probs(self, pixel_values: torch.Tensor) -> torch.Tensor:
  1315. z_logits = self.blocks(pixel_values)
  1316. return nn.Softmax(dim=1)(z_logits)
  1317. def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
  1318. """
  1319. Args:
  1320. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  1321. Pixel values. Codebook pixel values can be obtained using [`AutoImageProcessor`] by passing
  1322. `return_codebook_pixels=True`. See [`FlavaImageProcessor.__call__`] for details.
  1323. Examples:
  1324. ```python
  1325. >>> from PIL import Image
  1326. >>> import requests
  1327. >>> from transformers import AutoImageProcessor, FlavaImageCodebook
  1328. >>> model = FlavaImageCodebook.from_pretrained("{0}")
  1329. >>> image_processor = AutoImageProcessor.from_pretrained("{0}")
  1330. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1331. >>> image = Image.open(requests.get(url, stream=True).raw)
  1332. >>> inputs = image_processor([image], return_codebook_pixels=True, return_tensors="pt")
  1333. >>> inputs = dict(pixel_values=inputs.codebook_pixel_values)
  1334. >>> outputs = model(**inputs)
  1335. >>> print(outputs.shape)
  1336. (1, 196)
  1337. ```
  1338. """.format(_CHECKPOINT_FOR_CODEBOOK_DOC)
  1339. if len(pixel_values.shape) != 4:
  1340. raise ValueError(f"input shape {pixel_values.shape} is not 4d")
  1341. if pixel_values.shape[1] != self.input_channels:
  1342. raise ValueError(f"input has {pixel_values.shape[1]} channels but model built for {self.input_channels}")
  1343. return self.blocks(pixel_values)
  1344. class FlavaPredictionHeadTransform(nn.Module):
  1345. def __init__(self, config):
  1346. super().__init__()
  1347. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  1348. if isinstance(config.hidden_act, str):
  1349. self.transform_act_fn = ACT2FN[config.hidden_act]
  1350. else:
  1351. self.transform_act_fn = config.hidden_act
  1352. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  1353. def forward(self, hidden_states):
  1354. hidden_states = self.dense(hidden_states)
  1355. hidden_states = self.transform_act_fn(hidden_states)
  1356. hidden_states = self.LayerNorm(hidden_states)
  1357. return hidden_states
  1358. class FlavaMaskedPredictionHead(nn.Module):
  1359. def __init__(self, config, weight=None):
  1360. super().__init__()
  1361. self.config = config
  1362. self.transform = FlavaPredictionHeadTransform(config)
  1363. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1364. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  1365. if weight is not None:
  1366. self.decoder.weight = weight
  1367. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  1368. self.decoder.bias = self.bias
  1369. def _tie_weights(self):
  1370. self.decoder.bias = self.bias
  1371. def forward(self, x):
  1372. x = self.transform(x)
  1373. x = self.decoder(x)
  1374. return x
  1375. class FlavaITMHead(nn.Module):
  1376. def __init__(self, config):
  1377. super().__init__()
  1378. self.config = config
  1379. self.pooler = FlavaPooler(config)
  1380. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  1381. def forward(self, x):
  1382. x = self.pooler(x)
  1383. x = self.seq_relationship(x)
  1384. return x
  1385. class FlavaGlobalContrastiveHead(nn.Module):
  1386. def __init__(self, config):
  1387. super().__init__()
  1388. self.config = config
  1389. self.global_backprop_contrastive = config.global_backprop_contrastive
  1390. def forward(self, image_embeddings, text_embeddings, logit_scale):
  1391. temperature = torch.exp(logit_scale)
  1392. if not torch.distributed.is_available() or not torch.distributed.is_initialized():
  1393. labels = torch.arange(image_embeddings.size(0), device=image_embeddings.device)
  1394. image_embeddings_all = [image_embeddings]
  1395. text_embeddings_all = [text_embeddings]
  1396. else:
  1397. local_batch_size = image_embeddings.size(0)
  1398. world_size = torch.distributed.get_world_size()
  1399. if self.global_backprop_contrastive:
  1400. # `torch.distributed.nn.functional.all_gather` does backprop on all active workers
  1401. # whereas `torch.distributed.all_gather` does only backpropagates on the current worker.
  1402. image_embeddings_all = torch.distributed.nn.functional.all_gather(image_embeddings)
  1403. text_embeddings_all = torch.distributed.nn.functional.all_gather(text_embeddings)
  1404. else:
  1405. image_embeddings_all = [torch.zeros_like(text_embeddings) for _ in range(world_size)]
  1406. text_embeddings_all = [torch.zeros_like(image_embeddings) for _ in range(world_size)]
  1407. torch.distributed.all_gather(image_embeddings_all, image_embeddings)
  1408. torch.distributed.all_gather(text_embeddings_all, text_embeddings)
  1409. labels = local_batch_size * torch.distributed.get_rank() + torch.arange(
  1410. local_batch_size, device=image_embeddings.device
  1411. )
  1412. image_embeddings_all = torch.cat(image_embeddings_all)
  1413. text_embeddings_all = torch.cat(text_embeddings_all)
  1414. logits_per_image = torch.matmul(image_embeddings, text_embeddings_all.transpose(0, 1)) * temperature
  1415. logits_per_text = torch.matmul(text_embeddings, image_embeddings_all.transpose(0, 1)) * temperature
  1416. return logits_per_image, logits_per_text, labels
  1417. @add_start_docstrings(
  1418. """
  1419. The FLAVA model for pretraining which outputs losses, embeddings, logits and transformer outputs.
  1420. """,
  1421. FLAVA_START_DOCSTRING.format(config="FlavaConfig") + FLAVA_PRETRAINING_START_DOCSTRING_EXTRA,
  1422. )
  1423. class FlavaForPreTraining(FlavaPreTrainedModel):
  1424. # Those are linked to xxx.bias
  1425. _tied_weights_keys = [
  1426. "mmm_text_head.decoder.bias",
  1427. "mmm_image_head.decoder.bias",
  1428. "mlm_head.decoder.bias",
  1429. "mim_head.decoder.bias",
  1430. ]
  1431. def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None):
  1432. super().__init__(config)
  1433. self.flava = FlavaModel(config)
  1434. self.image_codebook = image_codebook
  1435. if self.image_codebook is None and config.init_codebook:
  1436. self.image_codebook = FlavaImageCodebook(config.image_codebook_config)
  1437. # Levarage text and image encoder configs to create the masked
  1438. # head since it has the right vocab
  1439. self.mim_head = FlavaMaskedPredictionHead(config.image_config)
  1440. self.mlm_head = FlavaMaskedPredictionHead(config.text_config)
  1441. self.itm_head = FlavaITMHead(config)
  1442. self.mmm_image_head = FlavaMaskedPredictionHead(config.image_config)
  1443. self.mmm_text_head = FlavaMaskedPredictionHead(config.text_config)
  1444. self.global_contrastive_head = FlavaGlobalContrastiveHead(config)
  1445. self.image_vocab_size = config.image_config.vocab_size
  1446. self.text_vocab_size = config.text_config.vocab_size
  1447. self.mlm_weight = config.mlm_weight
  1448. self.mim_weight = config.mim_weight
  1449. self.global_contrastive_weight = config.global_contrastive_weight
  1450. self.ce_ignore_index = config.ce_ignore_index
  1451. self.itm_weight = config.itm_weight
  1452. self.mmm_image_weight = config.mmm_image_weight
  1453. self.mmm_text_weight = config.mmm_text_weight
  1454. self.skip_unmasked_multimodal_encoder = config.skip_unmasked_multimodal_encoder
  1455. self.post_init()
  1456. def _resize_to_2d(self, x: torch.Tensor):
  1457. if x.dim() > 2:
  1458. x = x.view(x.size(0), -1)
  1459. return x
  1460. @add_start_docstrings_to_model_forward(
  1461. FLAVA_PRETRAINING_INPUTS_DOCSTRING.format("batch_size, text_seq_len", "batch_size, image_num_patches")
  1462. )
  1463. @replace_return_docstrings(output_type=FlavaForPreTrainingOutput, config_class=FlavaConfig)
  1464. def forward(
  1465. self,
  1466. input_ids: Optional[torch.LongTensor] = None,
  1467. input_ids_masked: Optional[torch.LongTensor] = None,
  1468. pixel_values: Optional[torch.FloatTensor] = None,
  1469. codebook_pixel_values: Optional[torch.FloatTensor] = None,
  1470. attention_mask: Optional[torch.Tensor] = None,
  1471. token_type_ids: Optional[torch.Tensor] = None,
  1472. bool_masked_pos: Optional[torch.Tensor] = None,
  1473. position_ids: Optional[torch.LongTensor] = None,
  1474. image_attention_mask: Optional[torch.Tensor] = None,
  1475. skip_unmasked_multimodal_encoder: bool = None,
  1476. mlm_labels: Optional[torch.Tensor] = None,
  1477. mim_labels: Optional[torch.Tensor] = None,
  1478. itm_labels: Optional[torch.Tensor] = None,
  1479. output_attentions: Optional[bool] = None,
  1480. output_hidden_states: bool = True,
  1481. return_dict: Optional[bool] = None,
  1482. return_loss: Optional[bool] = None,
  1483. ) -> Union[Tuple[torch.Tensor], FlavaForPreTrainingOutput]:
  1484. """
  1485. Examples:
  1486. ```python
  1487. >>> from PIL import Image
  1488. >>> import requests
  1489. >>> from transformers import FlavaForPreTraining, AutoProcessor
  1490. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  1491. >>> image = Image.open(requests.get(url, stream=True).raw)
  1492. >>> model = FlavaForPreTraining.from_pretrained("facebook/flava-full")
  1493. >>> processor = AutoProcessor.from_pretrained("facebook/flava-full")
  1494. >>> text = ["a photo of a cat"]
  1495. >>> inputs = processor(
  1496. ... images=[image],
  1497. ... text=text,
  1498. ... return_masks=True,
  1499. ... return_codebook_pixels=True,
  1500. ... padding=True,
  1501. ... max_length=77,
  1502. ... return_tensors="pt",
  1503. ... )
  1504. >>> output = model(**inputs)
  1505. ```
  1506. Return:
  1507. """
  1508. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1509. return_loss = return_loss if return_loss is not None else self.config.return_loss
  1510. skip_unmasked_multimodal_encoder = (
  1511. skip_unmasked_multimodal_encoder
  1512. if skip_unmasked_multimodal_encoder is not None
  1513. else self.skip_unmasked_multimodal_encoder
  1514. )
  1515. if input_ids_masked is None and input_ids is not None:
  1516. logger.warning(
  1517. "`input_ids_masked` isn't passed which means MLM loss won't be calculated correctlySetting it to"
  1518. " `input_ids` so that model can work. Please pass it if this is unintentional. This is usually OKAY if"
  1519. " you are doing inference on unmasked text..."
  1520. )
  1521. input_ids_masked = input_ids
  1522. flava_output = self.flava(
  1523. input_ids=input_ids,
  1524. pixel_values=pixel_values,
  1525. attention_mask=attention_mask,
  1526. token_type_ids=token_type_ids,
  1527. position_ids=position_ids,
  1528. image_attention_mask=image_attention_mask,
  1529. # Don't need unmasked multimodal embedding for anything so skip it
  1530. # NOTE: ITM uses masked version
  1531. skip_multimodal_encoder=skip_unmasked_multimodal_encoder,
  1532. output_attentions=output_attentions,
  1533. output_hidden_states=output_hidden_states,
  1534. # Pass true to have deterministic outputs
  1535. return_dict=True,
  1536. )
  1537. flava_masked_output = self.flava(
  1538. input_ids=input_ids_masked,
  1539. pixel_values=pixel_values,
  1540. attention_mask=attention_mask,
  1541. token_type_ids=token_type_ids,
  1542. image_attention_mask=image_attention_mask,
  1543. bool_masked_pos=bool_masked_pos,
  1544. output_attentions=output_attentions,
  1545. output_hidden_states=output_hidden_states,
  1546. return_dict=True,
  1547. )
  1548. pos_mask = None
  1549. image_embeddings = flava_output.image_embeddings
  1550. text_embeddings = flava_output.text_embeddings
  1551. image_masked_embeddings = flava_masked_output.image_embeddings
  1552. text_masked_embeddings = flava_masked_output.text_embeddings
  1553. multimodal_masked_embeddings = flava_masked_output.multimodal_embeddings
  1554. total_loss = mim_loss = mlm_loss = mmm_text_loss = mmm_image_loss = gc_loss = itm_loss = None
  1555. mim_logits = mlm_logits = mmm_text_logits = mmm_image_logits = None
  1556. itm_logits = logits_per_image = logits_per_text = None
  1557. # Calculate mim_labels if necessary from the image_codebook
  1558. if image_masked_embeddings is not None or multimodal_masked_embeddings is not None:
  1559. if mim_labels is None and return_loss:
  1560. if self.image_codebook is None:
  1561. raise RuntimeError(
  1562. "`return_loss` is set to True but the image codebook is not initialized and no `mim_labels` "
  1563. " have been passed. Reinstantiate the model with `init_codebook` set to True or "
  1564. "pass in your custom `mim_labels`"
  1565. )
  1566. if codebook_pixel_values is None:
  1567. raise ValueError(
  1568. "`codebook_pixel_value` are required to generate `mim_labels` if loss is expected. "
  1569. "Call `AutoProcessor` with `return_codebook_pixels` set to True"
  1570. )
  1571. mim_labels = self.image_codebook.get_codebook_indices(codebook_pixel_values)
  1572. # Unimodal MIM Loss
  1573. # If multimodal embeddings are present, we will calculate MMM loss
  1574. if self.mim_weight > 0 and image_masked_embeddings is not None and multimodal_masked_embeddings is None:
  1575. sequence_for_image = image_masked_embeddings
  1576. if mim_labels is not None:
  1577. mim_labels = self._resize_to_2d(mim_labels)
  1578. bool_masked_pos = self._resize_to_2d(bool_masked_pos)
  1579. mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index
  1580. sequence_for_image = sequence_for_image[:, -mim_labels.size(1) :, :]
  1581. masked_tokens = mim_labels.ne(self.ce_ignore_index)
  1582. mim_labels_filtered = mim_labels[masked_tokens]
  1583. sequence_for_image = sequence_for_image[masked_tokens, :]
  1584. mim_logits = self.mim_head(sequence_for_image)
  1585. if return_loss:
  1586. mim_loss = nn.functional.cross_entropy(
  1587. mim_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
  1588. )
  1589. mim_loss *= self.mim_weight
  1590. else:
  1591. mim_logits = self.mim_head(sequence_for_image)
  1592. # Unimodal MLM Loss
  1593. if self.mlm_weight > 0 and text_masked_embeddings is not None and multimodal_masked_embeddings is None:
  1594. sequence_for_text = text_masked_embeddings
  1595. if mlm_labels is not None:
  1596. mlm_labels = self._resize_to_2d(mlm_labels)
  1597. sequence_for_text = sequence_for_text[:, -mlm_labels.size(1) :, :]
  1598. masked_tokens = mlm_labels.ne(self.ce_ignore_index)
  1599. mlm_labels_filtered = mlm_labels[masked_tokens]
  1600. sequence_for_text = sequence_for_text[masked_tokens, :]
  1601. mlm_logits = self.mlm_head(sequence_for_text)
  1602. if return_loss:
  1603. mlm_loss = nn.functional.cross_entropy(
  1604. mlm_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
  1605. )
  1606. mlm_loss *= self.mlm_weight
  1607. else:
  1608. mlm_logits = self.mlm_head(sequence_for_text)
  1609. # ITM Loss
  1610. if self.itm_weight > 0 and multimodal_masked_embeddings is not None:
  1611. itm_logits = self.itm_head(multimodal_masked_embeddings)
  1612. if itm_labels is not None:
  1613. pos_pairs = itm_labels.ne(0)
  1614. pos_mask = torch.where(pos_pairs.any(), pos_pairs, pos_pairs.new([True]))
  1615. if return_loss:
  1616. itm_loss = nn.functional.cross_entropy(itm_logits, itm_labels)
  1617. itm_loss *= self.itm_weight
  1618. if multimodal_masked_embeddings is not None:
  1619. multimodal_masked_embeddings = multimodal_masked_embeddings[pos_mask]
  1620. if mlm_labels is not None:
  1621. mlm_labels = mlm_labels[pos_mask]
  1622. if mim_labels is not None:
  1623. mim_labels = mim_labels[pos_mask]
  1624. bool_masked_pos = bool_masked_pos[pos_mask]
  1625. # MMM Image Loss
  1626. if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0:
  1627. sequence_for_image = multimodal_masked_embeddings
  1628. end_index = image_masked_embeddings.size(1) - 1
  1629. sequence_for_image = sequence_for_image[:, 2 : 2 + end_index, :]
  1630. if mim_labels is not None:
  1631. mim_labels = self._resize_to_2d(mim_labels)
  1632. bool_masked_pos = self._resize_to_2d(bool_masked_pos)
  1633. mim_labels[bool_masked_pos.ne(True)] = self.ce_ignore_index
  1634. masked_tokens = mim_labels.ne(self.ce_ignore_index)
  1635. mim_labels_filtered = mim_labels[masked_tokens]
  1636. sequence_for_image = sequence_for_image[masked_tokens, :]
  1637. mmm_image_logits = self.mmm_image_head(sequence_for_image)
  1638. if return_loss:
  1639. mmm_image_loss = nn.functional.cross_entropy(
  1640. mmm_image_logits.view(-1, self.image_vocab_size), mim_labels_filtered.view(-1)
  1641. )
  1642. mmm_image_loss *= self.mmm_image_weight
  1643. else:
  1644. mmm_image_logits = self.mmm_image_head(sequence_for_image)
  1645. # MMM Text Loss
  1646. if multimodal_masked_embeddings is not None and self.mmm_text_weight > 0:
  1647. sequence_for_text = multimodal_masked_embeddings
  1648. sequence_for_text = sequence_for_text[:, -text_masked_embeddings.size(1) :, :]
  1649. if mlm_labels is not None:
  1650. mlm_labels = self._resize_to_2d(mlm_labels)
  1651. masked_tokens = mlm_labels.ne(self.ce_ignore_index)
  1652. mlm_labels_filtered = mlm_labels[masked_tokens]
  1653. sequence_for_text = sequence_for_text[masked_tokens, :]
  1654. mmm_text_logits = self.mmm_text_head(sequence_for_text)
  1655. if return_loss:
  1656. mmm_text_loss = nn.functional.cross_entropy(
  1657. mmm_text_logits.view(-1, self.text_vocab_size), mlm_labels_filtered.view(-1)
  1658. )
  1659. mmm_text_loss *= self.mmm_text_weight
  1660. else:
  1661. mmm_text_logits = self.mmm_text_head(sequence_for_text)
  1662. # Global Contrastive Loss
  1663. if image_embeddings is not None and text_embeddings is not None and self.global_contrastive_weight > 0:
  1664. text_embedding = self.flava.text_projection(text_embeddings[:, 0, :])
  1665. text_embedding = nn.functional.normalize(text_embedding, dim=-1)
  1666. image_embedding = self.flava.image_projection(image_embeddings[:, 0, :])
  1667. image_embedding = nn.functional.normalize(image_embedding, dim=-1)
  1668. self.flava.logit_scale.data.clamp_(LOGIT_SCALE_CLAMP_MIN, LOGIT_SCALE_CLAMP_MAX)
  1669. logits_per_image, logits_per_text, gc_labels = self.global_contrastive_head(
  1670. image_embedding, text_embedding, self.flava.logit_scale
  1671. )
  1672. # Apply ITM negative mask if any
  1673. if pos_mask is not None:
  1674. logits_per_image = logits_per_image[pos_mask]
  1675. logits_per_text = logits_per_text[pos_mask]
  1676. gc_labels = gc_labels[pos_mask]
  1677. if return_loss:
  1678. gc_loss_image = nn.functional.cross_entropy(logits_per_image, gc_labels)
  1679. gc_loss_text = nn.functional.cross_entropy(logits_per_text, gc_labels)
  1680. gc_loss = (gc_loss_image + gc_loss_text) / 2
  1681. gc_loss *= self.global_contrastive_weight
  1682. flava_losses = FlavaLosses(
  1683. mim=mim_loss,
  1684. mlm=mlm_loss,
  1685. itm=itm_loss,
  1686. global_contrastive=gc_loss,
  1687. mmm_image=mmm_image_loss,
  1688. mmm_text=mmm_text_loss,
  1689. )
  1690. if return_loss and not flava_losses.all_none():
  1691. total_loss = sum(loss if loss is not None else 0 for loss in flava_losses.values())
  1692. if not return_dict:
  1693. output = (
  1694. image_embeddings,
  1695. flava_output.image_output.to_tuple() if flava_output.image_output is not None else None,
  1696. text_embeddings,
  1697. flava_output.text_output.to_tuple() if flava_output.text_output is not None else None,
  1698. flava_output.multimodal_embeddings,
  1699. flava_output.multimodal_output.to_tuple() if flava_output.multimodal_output is not None else None,
  1700. image_masked_embeddings,
  1701. flava_masked_output.image_output.to_tuple() if flava_masked_output.image_output is not None else None,
  1702. text_masked_embeddings,
  1703. flava_masked_output.text_output.to_tuple() if flava_masked_output.text_output is not None else None,
  1704. multimodal_masked_embeddings,
  1705. flava_masked_output.multimodal_output.to_tuple()
  1706. if flava_masked_output.multimodal_output is not None
  1707. else None,
  1708. mim_logits,
  1709. mlm_logits,
  1710. itm_logits,
  1711. logits_per_image,
  1712. logits_per_image,
  1713. mmm_image_logits,
  1714. mmm_text_logits,
  1715. )
  1716. if return_loss and not flava_losses.all_none():
  1717. output = (
  1718. total_loss,
  1719. flava_losses,
  1720. ) + output
  1721. # Filter None as transformer by default won't handle it
  1722. return tuple(x for x in output if x is None)
  1723. return FlavaForPreTrainingOutput(
  1724. loss=total_loss,
  1725. loss_info=flava_losses,
  1726. image_embeddings=image_embeddings,
  1727. image_output=flava_output.image_output,
  1728. text_embeddings=text_embeddings,
  1729. text_output=flava_output.text_output,
  1730. multimodal_embeddings=flava_output.multimodal_embeddings,
  1731. multimodal_output=flava_output.multimodal_output,
  1732. image_masked_embeddings=image_masked_embeddings,
  1733. image_masked_output=flava_masked_output.image_output,
  1734. text_masked_embeddings=text_masked_embeddings,
  1735. text_masked_output=flava_masked_output.text_output,
  1736. multimodal_masked_embeddings=multimodal_masked_embeddings,
  1737. multimodal_masked_output=flava_masked_output.multimodal_output,
  1738. mim_logits=mim_logits,
  1739. mlm_logits=mlm_logits,
  1740. itm_logits=itm_logits,
  1741. contrastive_logits_per_image=logits_per_image,
  1742. contrastive_logits_per_text=logits_per_text,
  1743. mmm_image_logits=mmm_image_logits,
  1744. mmm_text_logits=mmm_text_logits,
  1745. )