# coding=utf-8
# Copyright 2023 The Google Research Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ALIGN model."""

import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
    BaseModelOutputWithPoolingAndNoAttention,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_align import AlignConfig, AlignTextConfig, AlignVisionConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "kakaobrain/align-base"
_CONFIG_FOR_DOC = "AlignConfig"


ALIGN_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`AlignConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

ALIGN_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

ALIGN_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`EfficientNetImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

ALIGN_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`EfficientNetImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@dataclass
class AlignVisionModelOutput(ModelOutput):
    """
    Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.

    Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The image embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    image_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class AlignTextModelOutput(ModelOutput):
    """
    Base class for text model's outputs that also contains a pooling of the last hidden states.

    Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
            The text embeddings obtained by applying the projection layer to the pooler_output.
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    text_embeds: Optional[torch.FloatTensor] = None
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class AlignOutput(ModelOutput):
    """
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
            Contrastive loss for image-text similarity.
        logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
            The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
            similarity scores.
        logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
            The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
            similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The text embeddings obtained by applying the projection layer to the pooled output of [`AlignTextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
            The output of [`AlignVisionModel`].
        text_model_output (`BaseModelOutputWithPoolingAndCrossAttentions`):
            The output of the [`AlignTextModel`].
        vision_model_output (`BaseModelOutputWithPoolingAndNoAttention`):
            The output of the [`AlignVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_model_output: BaseModelOutputWithPoolingAndCrossAttentions = None
    vision_model_output: BaseModelOutputWithPoolingAndNoAttention = None

    def to_tuple(self) -> Tuple[Any]:
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )


# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device), label_smoothing=0.1)


def align_loss(similarity: torch.Tensor) -> torch.Tensor:
    caption_loss = contrastive_loss(similarity)
    image_loss = contrastive_loss(similarity.t())
    return (caption_loss + image_loss) / 2.0
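

# Illustrative sketch, not part of the original ALIGN code: `align_loss` expects a
# square (text_batch, image_batch) similarity matrix in which matching pairs sit on
# the diagonal, so both directions of the symmetric loss use `torch.arange(batch)`
# as targets via `contrastive_loss` above. The scaling factor below is a hypothetical
# stand-in for the temperature the full model learns.
def _example_align_loss_usage() -> torch.Tensor:
    text_embeds = nn.functional.normalize(torch.randn(4, 8), dim=-1)
    image_embeds = nn.functional.normalize(torch.randn(4, 8), dim=-1)
    logits_per_text = torch.matmul(text_embeds, image_embeds.t()) / 0.05
    return align_loss(logits_per_text)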


# Copied from transformers.models.efficientnet.modeling_efficientnet.round_filters with EfficientNet->AlignVision
def round_filters(config: AlignVisionConfig, num_channels: int):
    r"""
    Round number of filters based on depth multiplier.
    """
    divisor = config.depth_divisor
    num_channels *= config.width_coefficient
    new_dim = max(divisor, int(num_channels + divisor / 2) // divisor * divisor)

    # Make sure that round down does not go down by more than 10%.
    if new_dim < 0.9 * num_channels:
        new_dim += divisor

    return int(new_dim)
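

# Worked example, illustrative only and not part of the original file; the width and
# divisor values passed here mirror the EfficientNet-B7-style settings used by ALIGN.
def _example_round_filters() -> int:
    config = AlignVisionConfig(width_coefficient=2.0, depth_divisor=8)
    # 32 * 2.0 = 64 is already a multiple of 8, so the snapped value is 64 and the
    # "round down by more than 10%" guard never triggers.
    return round_filters(config, 32)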


# Copied from transformers.models.efficientnet.modeling_efficientnet.correct_pad
def correct_pad(kernel_size: Union[int, Tuple], adjust: bool = True):
    r"""
    Utility function to get the tuple padding value for the depthwise convolution.

    Args:
        kernel_size (`int` or `tuple`):
            Kernel size of the convolution layers.
        adjust (`bool`, *optional*, defaults to `True`):
            Adjusts padding value to apply to right and bottom sides of the input.
    """
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    correct = (kernel_size[0] // 2, kernel_size[1] // 2)
    if adjust:
        return (correct[1] - 1, correct[1], correct[0] - 1, correct[0])
    else:
        return (correct[1], correct[1], correct[0], correct[0])
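

# Illustrative example, not part of the original file: for a 5x5 depthwise kernel,
# `correct` is (2, 2). In nn.ZeroPad2d's (left, right, top, bottom) order,
# `adjust=True` trims one pixel from the left/top, while `adjust=False` pads symmetrically.
def _example_correct_pad() -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
    return correct_pad(5, adjust=True), correct_pad(5, adjust=False)  # (1, 2, 1, 2), (2, 2, 2, 2)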


# Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetEmbeddings with EfficientNet->AlignVision
class AlignVisionEmbeddings(nn.Module):
    r"""
    A module that corresponds to the stem module of the original work.
    """

    def __init__(self, config: AlignVisionConfig):
        super().__init__()

        self.out_dim = round_filters(config, 32)
        self.padding = nn.ZeroPad2d(padding=(0, 1, 0, 1))
        self.convolution = nn.Conv2d(
            config.num_channels, self.out_dim, kernel_size=3, stride=2, padding="valid", bias=False
        )
        self.batchnorm = nn.BatchNorm2d(self.out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum)
        self.activation = ACT2FN[config.hidden_act]

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        features = self.padding(pixel_values)
        features = self.convolution(features)
        features = self.batchnorm(features)
        features = self.activation(features)

        return features


# Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetDepthwiseConv2d with EfficientNet->AlignVision
class AlignVisionDepthwiseConv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        depth_multiplier=1,
        kernel_size=3,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        padding_mode="zeros",
    ):
        out_channels = in_channels * depth_multiplier
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
            padding_mode=padding_mode,
        )


# Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetExpansionLayer with EfficientNet->AlignVision
class AlignVisionExpansionLayer(nn.Module):
    r"""
    This corresponds to the expansion phase of each block in the original implementation.
    """

    def __init__(self, config: AlignVisionConfig, in_dim: int, out_dim: int, stride: int):
        super().__init__()
        self.expand_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.expand_bn = nn.BatchNorm2d(num_features=out_dim, eps=config.batch_norm_eps)
        self.expand_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Expand phase
        hidden_states = self.expand_conv(hidden_states)
        hidden_states = self.expand_bn(hidden_states)
        hidden_states = self.expand_act(hidden_states)

        return hidden_states


# Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetDepthwiseLayer with EfficientNet->AlignVision
class AlignVisionDepthwiseLayer(nn.Module):
    r"""
    This corresponds to the depthwise convolution phase of each block in the original implementation.
    """

    def __init__(
        self,
        config: AlignVisionConfig,
        in_dim: int,
        stride: int,
        kernel_size: int,
        adjust_padding: bool,
    ):
        super().__init__()
        self.stride = stride
        conv_pad = "valid" if self.stride == 2 else "same"
        padding = correct_pad(kernel_size, adjust=adjust_padding)

        self.depthwise_conv_pad = nn.ZeroPad2d(padding=padding)
        self.depthwise_conv = AlignVisionDepthwiseConv2d(
            in_dim, kernel_size=kernel_size, stride=stride, padding=conv_pad, bias=False
        )
        self.depthwise_norm = nn.BatchNorm2d(
            num_features=in_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.depthwise_act = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        # Depthwise convolution
        if self.stride == 2:
            hidden_states = self.depthwise_conv_pad(hidden_states)

        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.depthwise_norm(hidden_states)
        hidden_states = self.depthwise_act(hidden_states)

        return hidden_states


# Copied from transformers.models.efficientnet.modeling_efficientnet.EfficientNetSqueezeExciteLayer with EfficientNet->AlignVision
class AlignVisionSqueezeExciteLayer(nn.Module):
    r"""
    This corresponds to the Squeeze and Excitation phase of each block in the original implementation.
    """

    def __init__(self, config: AlignVisionConfig, in_dim: int, expand_dim: int, expand: bool = False):
        super().__init__()
        self.dim = expand_dim if expand else in_dim
        self.dim_se = max(1, int(in_dim * config.squeeze_expansion_ratio))

        self.squeeze = nn.AdaptiveAvgPool2d(output_size=1)
        self.reduce = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim_se,
            kernel_size=1,
            padding="same",
        )
        self.expand = nn.Conv2d(
            in_channels=self.dim_se,
            out_channels=self.dim,
            kernel_size=1,
            padding="same",
        )
        self.act_reduce = ACT2FN[config.hidden_act]
        self.act_expand = nn.Sigmoid()

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        inputs = hidden_states
        hidden_states = self.squeeze(hidden_states)
        hidden_states = self.reduce(hidden_states)
        hidden_states = self.act_reduce(hidden_states)

        hidden_states = self.expand(hidden_states)
        hidden_states = self.act_expand(hidden_states)
        hidden_states = torch.mul(inputs, hidden_states)

        return hidden_states


class AlignVisionFinalBlockLayer(nn.Module):
    r"""
    This corresponds to the final phase of each block in the original implementation.
    """

    def __init__(
        self, config: AlignVisionConfig, in_dim: int, out_dim: int, stride: int, drop_rate: float, id_skip: bool
    ):
        super().__init__()
        self.apply_dropout = stride == 1 and not id_skip
        self.project_conv = nn.Conv2d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=1,
            padding="same",
            bias=False,
        )
        self.project_bn = nn.BatchNorm2d(
            num_features=out_dim, eps=config.batch_norm_eps, momentum=config.batch_norm_momentum
        )
        self.dropout = nn.Dropout(p=drop_rate)

    def forward(self, embeddings: torch.FloatTensor, hidden_states: torch.FloatTensor) -> torch.Tensor:
        hidden_states = self.project_conv(hidden_states)
        hidden_states = self.project_bn(hidden_states)

        if self.apply_dropout:
            hidden_states = self.dropout(hidden_states)
            hidden_states = hidden_states + embeddings

        return hidden_states


class AlignVisionBlock(nn.Module):
    r"""
    This corresponds to the block module of the original EfficientNet vision encoder implementation.

    Args:
        config ([`AlignVisionConfig`]):
            Model configuration class.
        in_dim (`int`):
            Number of input channels.
        out_dim (`int`):
            Number of output channels.
        stride (`int`):
            Stride size to be used in convolution layers.
        expand_ratio (`int`):
            Expand ratio to set the output dimensions for the expansion and squeeze-excite layers.
        kernel_size (`int`):
            Kernel size for the depthwise convolution layer.
        drop_rate (`float`):
            Dropout rate to be used in the final phase of each block.
        id_skip (`bool`):
            Whether to apply dropout and sum the final hidden states with the input embeddings during the final phase
            of each block. Set to `True` for the first block of each stage.
        adjust_padding (`bool`):
            Whether to apply padding to only right and bottom side of the input kernel before the depthwise convolution
            operation, set to `True` for inputs with odd input sizes.
    """

    def __init__(
        self,
        config: AlignVisionConfig,
        in_dim: int,
        out_dim: int,
        stride: int,
        expand_ratio: int,
        kernel_size: int,
        drop_rate: float,
        id_skip: bool,
        adjust_padding: bool,
    ):
        super().__init__()
        self.expand_ratio = expand_ratio
        self.expand = True if self.expand_ratio != 1 else False
        expand_in_dim = in_dim * expand_ratio

        if self.expand:
            self.expansion = AlignVisionExpansionLayer(
                config=config, in_dim=in_dim, out_dim=expand_in_dim, stride=stride
            )

        self.depthwise_conv = AlignVisionDepthwiseLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            stride=stride,
            kernel_size=kernel_size,
            adjust_padding=adjust_padding,
        )
        self.squeeze_excite = AlignVisionSqueezeExciteLayer(
            config=config, in_dim=in_dim, expand_dim=expand_in_dim, expand=self.expand
        )
        self.projection = AlignVisionFinalBlockLayer(
            config=config,
            in_dim=expand_in_dim if self.expand else in_dim,
            out_dim=out_dim,
            stride=stride,
            drop_rate=drop_rate,
            id_skip=id_skip,
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        embeddings = hidden_states
        # Expansion and depthwise convolution phase
        if self.expand_ratio != 1:
            hidden_states = self.expansion(hidden_states)
        hidden_states = self.depthwise_conv(hidden_states)

        # Squeeze and excite phase
        hidden_states = self.squeeze_excite(hidden_states)
        hidden_states = self.projection(embeddings, hidden_states)
        return hidden_states


class AlignVisionEncoder(nn.Module):
    r"""
    Forward propagates the embeddings through each vision encoder (EfficientNet) block.

    Args:
        config ([`AlignVisionConfig`]):
            Model configuration class.
    """

    def __init__(self, config: AlignVisionConfig):
        super().__init__()
        self.depth_coefficient = config.depth_coefficient

        def round_repeats(repeats):
            # Round number of block repeats based on depth multiplier.
            return int(math.ceil(self.depth_coefficient * repeats))

        num_base_blocks = len(config.in_channels)
        num_blocks = sum(round_repeats(n) for n in config.num_block_repeats)

        curr_block_num = 0
        blocks = []
        for i in range(num_base_blocks):
            in_dim = round_filters(config, config.in_channels[i])
            out_dim = round_filters(config, config.out_channels[i])
            stride = config.strides[i]
            kernel_size = config.kernel_sizes[i]
            expand_ratio = config.expand_ratios[i]

            for j in range(round_repeats(config.num_block_repeats[i])):
                id_skip = True if j == 0 else False
                stride = 1 if j > 0 else stride
                in_dim = out_dim if j > 0 else in_dim
                adjust_padding = False if curr_block_num in config.depthwise_padding else True
                drop_rate = config.drop_connect_rate * curr_block_num / num_blocks

                block = AlignVisionBlock(
                    config=config,
                    in_dim=in_dim,
                    out_dim=out_dim,
                    stride=stride,
                    kernel_size=kernel_size,
                    expand_ratio=expand_ratio,
                    drop_rate=drop_rate,
                    id_skip=id_skip,
                    adjust_padding=adjust_padding,
                )
                blocks.append(block)
                curr_block_num += 1

        self.blocks = nn.ModuleList(blocks)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> BaseModelOutputWithPoolingAndNoAttention:
        all_hidden_states = (hidden_states,) if output_hidden_states else None

        for block in self.blocks:
            hidden_states = block(hidden_states)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )
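

# Illustrative note, not part of the original file: `round_repeats` above scales each
# stage's block count by the depth multiplier, e.g. a stage listed with 4 repeats and
# depth_coefficient=3.1 is built with math.ceil(3.1 * 4) = 13 blocks. Only the first
# block of a stage uses the configured stride and input width; subsequent repeats run
# with stride 1 on `out_dim` channels, and `drop_rate` grows linearly with the block
# index toward `config.drop_connect_rate`.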


# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->AlignText
class AlignTextEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        # Setting the token_type_ids to the registered buffer from the constructor, where it is all zeros, which
        # usually occurs when it's auto-generated; the registered buffer helps users when tracing the model without
        # passing token_type_ids and solves issue #5664
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->AlignText
class AlignTextSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        elif is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        use_cache = past_key_value is not None
        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
                    -1, 1
                )
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the AlignTextModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs


# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->AlignText
class AlignTextSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


ALIGN_TEXT_SELF_ATTENTION_CLASSES = {
    "eager": AlignTextSelfAttention,
}


# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->AlignText,BERT->ALIGN_TEXT
class AlignTextAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None):
        super().__init__()
        self.self = ALIGN_TEXT_SELF_ATTENTION_CLASSES[config._attn_implementation](
            config, position_embedding_type=position_embedding_type
        )
        self.output = AlignTextSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->AlignText
class AlignTextIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->AlignText
class AlignTextOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->AlignText
class AlignTextLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = AlignTextAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = AlignTextAttention(config, position_embedding_type="absolute")
        self.intermediate = AlignTextIntermediate(config)
        self.output = AlignTextOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
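

# Illustrative note, not part of the original file: `apply_chunking_to_forward` in the
# layer above splits `attention_output` into slices of `chunk_size_feed_forward` along
# the sequence dimension (dim 1), runs `feed_forward_chunk` on each slice, and
# concatenates the results, which lowers peak activation memory for long sequences.
# With the default `chunk_size_feed_forward=0` the feed-forward runs in a single pass.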
  881. # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->AlignText
  882. class AlignTextEncoder(nn.Module):
  883. def __init__(self, config):
  884. super().__init__()
  885. self.config = config
  886. self.layer = nn.ModuleList([AlignTextLayer(config) for _ in range(config.num_hidden_layers)])
  887. self.gradient_checkpointing = False
  888. def forward(
  889. self,
  890. hidden_states: torch.Tensor,
  891. attention_mask: Optional[torch.FloatTensor] = None,
  892. head_mask: Optional[torch.FloatTensor] = None,
  893. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  894. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  895. past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  896. use_cache: Optional[bool] = None,
  897. output_attentions: Optional[bool] = False,
  898. output_hidden_states: Optional[bool] = False,
  899. return_dict: Optional[bool] = True,
  900. ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
  901. all_hidden_states = () if output_hidden_states else None
  902. all_self_attentions = () if output_attentions else None
  903. all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
  904. if self.gradient_checkpointing and self.training:
  905. if use_cache:
  906. logger.warning_once(
  907. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  908. )
  909. use_cache = False
  910. next_decoder_cache = () if use_cache else None
  911. for i, layer_module in enumerate(self.layer):
  912. if output_hidden_states:
  913. all_hidden_states = all_hidden_states + (hidden_states,)
  914. layer_head_mask = head_mask[i] if head_mask is not None else None
  915. past_key_value = past_key_values[i] if past_key_values is not None else None
  916. if self.gradient_checkpointing and self.training:
  917. layer_outputs = self._gradient_checkpointing_func(
  918. layer_module.__call__,
  919. hidden_states,
  920. attention_mask,
  921. layer_head_mask,
  922. encoder_hidden_states,
  923. encoder_attention_mask,
  924. past_key_value,
  925. output_attentions,
  926. )
  927. else:
  928. layer_outputs = layer_module(
  929. hidden_states,
  930. attention_mask,
  931. layer_head_mask,
  932. encoder_hidden_states,
  933. encoder_attention_mask,
  934. past_key_value,
  935. output_attentions,
  936. )
  937. hidden_states = layer_outputs[0]
  938. if use_cache:
  939. next_decoder_cache += (layer_outputs[-1],)
  940. if output_attentions:
  941. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  942. if self.config.add_cross_attention:
  943. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  944. if output_hidden_states:
  945. all_hidden_states = all_hidden_states + (hidden_states,)
  946. if not return_dict:
  947. return tuple(
  948. v
  949. for v in [
  950. hidden_states,
  951. next_decoder_cache,
  952. all_hidden_states,
  953. all_self_attentions,
  954. all_cross_attentions,
  955. ]
  956. if v is not None
  957. )
  958. return BaseModelOutputWithPastAndCrossAttentions(
  959. last_hidden_state=hidden_states,
  960. past_key_values=next_decoder_cache,
  961. hidden_states=all_hidden_states,
  962. attentions=all_self_attentions,
  963. cross_attentions=all_cross_attentions,
  964. )
  965. # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert -> AlignText
  966. class AlignTextPooler(nn.Module):
  967. def __init__(self, config):
  968. super().__init__()
  969. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  970. self.activation = nn.Tanh()
  971. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  972. # We "pool" the model by simply taking the hidden state corresponding
  973. # to the first token.
  974. first_token_tensor = hidden_states[:, 0]
  975. pooled_output = self.dense(first_token_tensor)
  976. pooled_output = self.activation(pooled_output)
  977. return pooled_output


class AlignPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = AlignConfig
    base_model_prefix = "align"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, AlignModel):
            nn.init.xavier_uniform_(module.text_projection.weight)
            module.text_projection.bias.data.zero_()
            module.text_projection._is_hf_initialized = True
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
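

# Note on initialization: `_init_weights` above is applied via `post_init()` in the concrete
# models below, and again by `from_pretrained` for any parameters missing from a checkpoint.
# Marking `text_projection._is_hf_initialized = True` tells that machinery the projection has
# already been initialized (Xavier-uniform weights, zero bias) so it is not overwritten later.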


@add_start_docstrings(
    """The text model from ALIGN without any head or projection on top.""",
    ALIGN_START_DOCSTRING,
)
class AlignTextModel(AlignPreTrainedModel):
    config_class = AlignTextConfig
    _no_split_modules = ["AlignTextEmbeddings"]

    def __init__(self, config: AlignTextConfig, add_pooling_layer: bool = True):
        super().__init__(config)
        self.config = config

        self.embeddings = AlignTextEmbeddings(config)
        self.encoder = AlignTextEncoder(config)

        self.pooler = AlignTextPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @add_start_docstrings_to_model_forward(ALIGN_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=AlignTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AlignTextModel

        >>> model = AlignTextModel.from_pretrained("kakaobrain/align-base")
        >>> tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (first [CLS] token) states
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves, in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
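

# How the two outputs of AlignTextModel.forward relate (a rough sketch in comment form;
# `model` and `inputs` are assumed to be set up as in the docstring example above):
#
#     with torch.no_grad():
#         out = model(**inputs)
#         manual = model.pooler.activation(model.pooler.dense(out.last_hidden_state[:, 0]))
#     # `manual` should match out.pooler_output up to floating-point noise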


@add_start_docstrings(
    """The vision model from ALIGN without any head or projection on top.""",
    ALIGN_START_DOCSTRING,
)
class AlignVisionModel(AlignPreTrainedModel):
    config_class = AlignVisionConfig
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False

    def __init__(self, config: AlignVisionConfig):
        super().__init__(config)
        self.config = config
        self.embeddings = AlignVisionEmbeddings(config)
        self.encoder = AlignVisionEncoder(config)

        # Final pooling layer
        if config.pooling_type == "mean":
            self.pooler = nn.AvgPool2d(config.hidden_dim, ceil_mode=True)
        elif config.pooling_type == "max":
            self.pooler = nn.MaxPool2d(config.hidden_dim, ceil_mode=True)
        else:
            raise ValueError(f"config.pooling_type must be one of ['mean', 'max'], got {config.pooling_type}")

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.embeddings.convolution

    @add_start_docstrings_to_model_forward(ALIGN_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPoolingAndNoAttention, config_class=AlignVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AlignVisionModel

        >>> model = AlignVisionModel.from_pretrained("kakaobrain/align-base")
        >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # spatially pooled convolutional features
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)
        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Apply pooling
        last_hidden_state = encoder_outputs[0]
        pooled_output = self.pooler(last_hidden_state)
        # Reshape (batch_size, projection_dim, 1, 1) -> (batch_size, projection_dim)
        pooled_output = pooled_output.reshape(pooled_output.shape[:2])

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
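

# Shape flow through AlignVisionModel.forward (a rough sketch; the exact spatial size H' x W'
# depends on the EfficientNet variant and input resolution used by the checkpoint):
#
#     pixel_values       : (batch_size, 3, H, W)
#     last_hidden_state  : (batch_size, hidden_dim, H', W')   # convolutional feature map
#     pooled_output      : (batch_size, hidden_dim)           # after 2D mean/max pooling + reshape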


@add_start_docstrings(ALIGN_START_DOCSTRING)
class AlignModel(AlignPreTrainedModel):
    config_class = AlignConfig

    def __init__(self, config: AlignConfig):
        super().__init__(config)

        if not isinstance(config.text_config, AlignTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type AlignTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, AlignVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type AlignVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size

        self.text_model = AlignTextModel(text_config)
        self.vision_model = AlignVisionModel(vision_config)

        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim)
        self.temperature = nn.Parameter(torch.tensor(self.config.temperature_init_value))

        # Initialize weights and apply final processing
        self.post_init()
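
    # Note on the two towers wired together in __init__: only the text side has a learned
    # projection (`text_projection`: hidden_size -> projection_dim); image embeddings are taken
    # directly from the vision model's pooled output, so `projection_dim` is expected to match
    # the vision `hidden_dim` for the similarity logits computed in `forward` to be well-formed.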

    @add_start_docstrings_to_model_forward(ALIGN_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
                applying the projection layer to the first-token ([CLS]) hidden state of [`AlignTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, AlignModel

        >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
        >>> tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        # Use ALIGN model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = text_outputs[0][:, 0, :]
        text_features = self.text_projection(last_hidden_state)

        return text_features

    @add_start_docstrings_to_model_forward(ALIGN_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained
                from the pooled output of [`AlignVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AlignModel

        >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
        >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        # Use ALIGN model's config for some fields (if specified) instead of those of vision & text components.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_features = vision_outputs[1]  # pooled_output

        return image_features
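
    # A retrieval-style sketch combining the two feature helpers above (kept as comments;
    # `model`, `text_inputs`, and `image_inputs` are assumed to be prepared as in the docstring
    # examples). Unlike `forward`, this skips the learned temperature, so it yields plain
    # cosine similarities rather than the model's calibrated logits:
    #
    #     import torch.nn.functional as F
    #
    #     text_emb = F.normalize(model.get_text_features(**text_inputs), dim=-1)
    #     image_emb = F.normalize(model.get_image_features(**image_inputs), dim=-1)
    #     cosine_sim = text_emb @ image_emb.T   # (num_texts, num_images)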

    @add_start_docstrings_to_model_forward(ALIGN_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=AlignOutput, config_class=AlignConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, AlignOutput]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, AlignModel

        >>> model = AlignModel.from_pretrained("kakaobrain/align-base")
        >>> processor = AutoProcessor.from_pretrained("kakaobrain/align-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     images=image, text=["a photo of a cat", "a photo of a dog"], return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # Use ALIGN model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        text_embeds = text_outputs[0][:, 0, :]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) / self.temperature
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = align_loss(logits_per_text)

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return AlignOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
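

# A minimal contrastive training-step sketch for AlignModel (kept as comments; assumes `model`
# and `inputs` are built as in the `forward` docstring example, with a batch of matched
# image-text pairs so the diagonal of the similarity matrix holds the positive pairs):
#
#     outputs = model(**inputs, return_loss=True)
#     outputs.loss.backward()   # align_loss averages the text->image and image->text cross-entropy terms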


__all__ = ["AlignPreTrainedModel", "AlignTextModel", "AlignVisionModel", "AlignModel"]