# coding=utf-8
# Copyright 2022 Meta Platforms and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Data2VecVision model."""

import collections.abc
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
    SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
    torch_int,
)
from .configuration_data2vec_vision import Data2VecVisionConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "Data2VecVisionConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/data2vec-vision-base"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/data2vec-vision-base-ft1k"
_IMAGE_CLASS_EXPECTED_OUTPUT = "remote control, remote"


@dataclass
# Copied from transformers.models.beit.modeling_beit.BeitModelOutputWithPooling with Beit->Data2VecVision
class Data2VecVisionModelOutputWithPooling(BaseModelOutputWithPooling):
    """
    Class for outputs of [`Data2VecVisionModel`].

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`):
            Average of the last layer hidden states of the patch tokens (excluding the *[CLS]* token) if
            *config.use_mean_pooling* is set to True. If set to False, then the final hidden state of the *[CLS]* token
            will be returned.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """


# Copied from transformers.models.beit.modeling_beit.drop_path
def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    """
    if drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
    random_tensor.floor_()  # binarize
    output = input.div(keep_prob) * random_tensor
    return output


# Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Data2VecVision
class Data2VecVisionDropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob: Optional[float] = None) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)
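

# Illustrative sketch (kept entirely in comments so the module stays import-safe; not part of the
# original source). It shows how stochastic depth behaves: during training, each sample's residual
# branch is either zeroed out with probability `drop_prob` or scaled by 1 / keep_prob, so the
# branch's expected value is preserved; at inference time the layer is the identity.
#
#     layer = Data2VecVisionDropPath(drop_prob=0.1)
#     layer.train()
#     residual = torch.ones(4, 197, 768)   # (batch, seq_len, hidden)
#     out = layer(residual)                # per sample: all zeros or all ~1.111 (= 1 / 0.9)
#     layer.eval()
#     out = layer(residual)                # identical to the input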


# Copied from transformers.models.beit.modeling_beit.BeitEmbeddings with Beit->Data2VecVision
class Data2VecVisionEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        if config.use_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        else:
            self.mask_token = None
        self.patch_embeddings = Data2VecVisionPatchEmbeddings(config)
        self.patch_size = config.patch_size
        self.image_size = (
            config.image_size
            if isinstance(config.image_size, collections.abc.Iterable)
            else (config.image_size, config.image_size)
        )
        num_patches = self.patch_embeddings.num_patches
        if config.use_absolute_position_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
        else:
            self.position_embeddings = None
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        _, _, height, width = pixel_values.shape
        embeddings, (patch_height, patch_width) = self.patch_embeddings(
            pixel_values, self.position_embeddings[:, 1:, :] if self.position_embeddings is not None else None
        )
        batch_size, seq_len, _ = embeddings.size()

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1 - w) + mask_tokens * w

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        if self.position_embeddings is not None:
            if interpolate_pos_encoding:
                cls_tokens = cls_tokens + self.interpolate_pos_encoding(embeddings, height, width)
            else:
                cls_tokens = cls_tokens + self.position_embeddings[:, :1, :]

        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        embeddings = self.dropout(embeddings)

        return embeddings, (patch_height, patch_width)


# Copied from transformers.models.beit.modeling_beit.BeitPatchEmbeddings with Beit->Data2VecVision
class Data2VecVisionPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.patch_shape = patch_shape

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(
        self,
        pixel_values: torch.Tensor,
        position_embedding: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
            )

        embeddings = self.projection(pixel_values)
        patch_height, patch_width = embeddings.shape[2], embeddings.shape[3]

        if position_embedding is not None:
            # interpolate the position embedding to the corresponding size
            position_embedding = position_embedding.view(1, self.patch_shape[0], self.patch_shape[1], -1).permute(
                0, 3, 1, 2
            )
            position_embedding = nn.functional.interpolate(
                position_embedding, size=(patch_height, patch_width), mode="bicubic"
            )
            embeddings = embeddings + position_embedding

        embeddings = embeddings.flatten(2).transpose(1, 2)

        return embeddings, (patch_height, patch_width)
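

# Shape sketch (comment only, not part of the original source), assuming the default base
# configuration with image_size=224, patch_size=16, num_channels=3 and hidden_size=768:
#
#     pixel_values:                              (batch_size, 3, 224, 224)
#     after self.projection (Conv2d, stride 16): (batch_size, 768, 14, 14)
#     after flatten(2).transpose(1, 2):          (batch_size, 196, 768)
#
# Data2VecVisionEmbeddings then prepends the CLS token, giving a sequence length of 197, which is
# where _EXPECTED_OUTPUT_SHAPE above comes from.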


# Copied from transformers.models.beit.modeling_beit.BeitSelfAttention with Beit->Data2VecVision
class Data2VecVisionSelfAttention(nn.Module):
    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        if window_size:
            self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
        else:
            self.relative_position_bias = None

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Add relative position bias if present.
        if self.relative_position_bias is not None:
            height, width = resolution
            window_size = (height // self.config.patch_size, width // self.config.patch_size)
            attention_scores = attention_scores + self.relative_position_bias(
                window_size, interpolate_pos_encoding, dim_size=hidden_states.shape[1]
            )

        # Add shared relative position bias if provided.
        if relative_position_bias is not None:
            attention_scores = attention_scores + relative_position_bias

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


# Copied from transformers.models.beit.modeling_beit.BeitSelfOutput with Beit->Data2VecVision
class Data2VecVisionSelfOutput(nn.Module):
    """
    The residual connection is defined in Data2VecVisionLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, gamma=None) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Copied from transformers.models.beit.modeling_beit.BeitAttention with Beit->Data2VecVision
class Data2VecVisionAttention(nn.Module):
    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
        self.attention = Data2VecVisionSelfAttention(config, window_size=window_size)
        self.output = Data2VecVisionSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        self_outputs = self.attention(
            hidden_states, head_mask, output_attentions, relative_position_bias, interpolate_pos_encoding, resolution
        )

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.beit.modeling_beit.BeitIntermediate with Beit->Data2VecVision
class Data2VecVisionIntermediate(nn.Module):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


# Copied from transformers.models.beit.modeling_beit.BeitOutput with Beit->Data2VecVision
class Data2VecVisionOutput(nn.Module):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Copied from transformers.models.beit.modeling_beit.BeitLayer with Beit->Data2VecVision,BEiT->Data2VecVision
class Data2VecVisionLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(
        self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None, drop_path_rate: float = 0.0
    ) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Data2VecVisionAttention(config, window_size=window_size)
        self.intermediate = Data2VecVisionIntermediate(config)
        self.output = Data2VecVisionOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.drop_path = Data2VecVisionDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        init_values = config.layer_scale_init_value
        if init_values > 0:
            self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
            self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
        else:
            self.lambda_1, self.lambda_2 = None, None

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        relative_position_bias: Optional["Data2VecVisionRelativePositionBias"] = None,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int]] = None,
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in Data2VecVision, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
            relative_position_bias=relative_position_bias,
            interpolate_pos_encoding=interpolate_pos_encoding,
            resolution=resolution,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # apply lambda_1 if present
        if self.lambda_1 is not None:
            attention_output = self.lambda_1 * attention_output

        # first residual connection
        hidden_states = self.drop_path(attention_output) + hidden_states

        # in Data2VecVision, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)

        layer_output = self.intermediate(layer_output)
        layer_output = self.output(layer_output)

        if self.lambda_2 is not None:
            layer_output = self.lambda_2 * layer_output

        # second residual connection
        layer_output = self.drop_path(layer_output) + hidden_states

        outputs = (layer_output,) + outputs

        return outputs


# Copied from transformers.models.beit.modeling_beit.BeitRelativePositionBias with Beit->Data2VecVision
class Data2VecVisionRelativePositionBias(nn.Module):
    def __init__(self, config: Data2VecVisionConfig, window_size: tuple) -> None:
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, config.num_attention_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        self.relative_position_indices = {}

    def generate_relative_position_index(self, window_size: Tuple[int, int]) -> torch.Tensor:
        """
        This method creates the relative position index, modified to support arbitrary window sizes,
        as introduced in [MiDaS v3.1](https://arxiv.org/abs/2307.14460).
        """
        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        # cls to token & token 2 cls & cls to cls
        # get pair-wise relative position index for each token inside the window
        window_area = window_size[0] * window_size[1]
        grid = torch.meshgrid(torch.arange(window_size[0]), torch.arange(window_size[1]), indexing="ij")
        coords = torch.stack(grid)  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = num_relative_distance - 3
        relative_position_index[0:, 0] = num_relative_distance - 2
        relative_position_index[0, 0] = num_relative_distance - 1
        return relative_position_index

    def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=None) -> torch.Tensor:
        """
        Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
        """
        old_height = 2 * self.window_size[0] - 1
        old_width = 2 * self.window_size[1] - 1

        new_height = 2 * window_size[0] - 1
        new_width = 2 * window_size[1] - 1

        old_relative_position_bias_table = self.relative_position_bias_table

        old_num_relative_distance = self.num_relative_distance
        new_num_relative_distance = new_height * new_width + 3

        old_sub_table = old_relative_position_bias_table[: old_num_relative_distance - 3]

        old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
        new_sub_table = nn.functional.interpolate(
            old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
        )
        new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)

        new_relative_position_bias_table = torch.cat(
            [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3 :]]
        )

        key = window_size
        if key not in self.relative_position_indices.keys():
            self.relative_position_indices[key] = self.generate_relative_position_index(window_size)

        relative_position_bias = new_relative_position_bias_table[self.relative_position_indices[key].view(-1)]
        # patch_size*num_patches_height, patch_size*num_patches_width, num_attention_heads
        relative_position_bias = relative_position_bias.view(
            window_size[0] * window_size[1] + 1, window_size[0] * window_size[1] + 1, -1
        )
        # num_attention_heads, patch_size*num_patches_width, patch_size*num_patches_height
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()

        if interpolate_pos_encoding:
            relative_position_bias = nn.functional.interpolate(
                relative_position_bias.unsqueeze(1),
                size=(dim_size, dim_size),
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)

        return relative_position_bias.unsqueeze(0)


# Copied from transformers.models.beit.modeling_beit.BeitEncoder with Beit->Data2VecVision
class Data2VecVisionEncoder(nn.Module):
    def __init__(self, config: Data2VecVisionConfig, window_size: Optional[tuple] = None) -> None:
        super().__init__()
        self.config = config
        if config.use_shared_relative_position_bias:
            self.relative_position_bias = Data2VecVisionRelativePositionBias(config, window_size=window_size)
        else:
            self.relative_position_bias = None

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
        self.layer = nn.ModuleList(
            [
                Data2VecVisionLayer(
                    config,
                    window_size=window_size if config.use_relative_position_bias else None,
                    drop_path_rate=dpr[i],
                )
                for i in range(config.num_hidden_layers)
            ]
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        interpolate_pos_encoding: bool = False,
        resolution: Optional[Tuple[int]] = None,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                height, width = resolution
                window_size = (height // self.config.patch_size, width // self.config.patch_size)
                relative_position_bias = (
                    self.relative_position_bias(
                        window_size, interpolate_pos_encoding=interpolate_pos_encoding, dim_size=hidden_states.shape[1]
                    )
                    if self.relative_position_bias is not None
                    else None
                )
                layer_outputs = layer_module(
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                    relative_position_bias,
                    interpolate_pos_encoding,
                    resolution,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
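

# Illustrative note on the stochastic depth decay rule above (comment only, not part of the
# original source): with num_hidden_layers=12 and, say, drop_path_rate=0.1, the per-layer drop
# probabilities increase linearly from 0.0 for the first layer to 0.1 for the last one:
#
#     dpr = [x.item() for x in torch.linspace(0, 0.1, 12)]
#     # 12 values from 0.0 to 0.1 in equal steps of roughly 0.009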


# Copied from transformers.models.beit.modeling_beit.BeitPreTrainedModel with Beit->Data2VecVision,beit->data2vec_vision
class Data2VecVisionPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Data2VecVisionConfig
    base_model_prefix = "data2vec_vision"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Data2VecVisionLayer"]
    _keys_to_ignore_on_load_unexpected = [r".*relative_position_index.*"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


DATA2VEC_VISION_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`Data2VecVisionConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

DATA2VEC_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`BeitImageProcessor.__call__`] for details.
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained position encodings.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare Data2VecVision Model transformer outputting raw hidden-states without any specific head on top.",
    DATA2VEC_VISION_START_DOCSTRING,
)
# Copied from transformers.models.beit.modeling_beit.BeitModel with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,True->False
class Data2VecVisionModel(Data2VecVisionPreTrainedModel):
    def __init__(self, config: Data2VecVisionConfig, add_pooling_layer: bool = False) -> None:
        super().__init__(config)
        self.config = config

        self.embeddings = Data2VecVisionEmbeddings(config)
        self.encoder = Data2VecVisionEncoder(config, window_size=self.embeddings.patch_embeddings.patch_shape)

        self.layernorm = (
            nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        )
        self.pooler = Data2VecVisionPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Data2VecVisionModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, Data2VecVisionModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output, _ = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        resolution = pixel_values.shape[2:]
        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            resolution=resolution,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return Data2VecVisionModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
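

# Hedged usage sketch (comment only so the module stays import-safe; the checkpoint name matches
# _CHECKPOINT_FOR_DOC above, everything else is illustrative):
#
#     from transformers import AutoImageProcessor, Data2VecVisionModel
#
#     processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
#     model = Data2VecVisionModel.from_pretrained("facebook/data2vec-vision-base")
#     inputs = processor(images=image, return_tensors="pt")   # `image` is a PIL.Image
#     outputs = model(**inputs)
#     outputs.last_hidden_state.shape                          # torch.Size([1, 197, 768])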


# Copied from transformers.models.beit.modeling_beit.BeitPooler with Beit->Data2VecVision
class Data2VecVisionPooler(nn.Module):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()
        self.layernorm = (
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.layernorm is not None:
            # Mean pool the final hidden states of the patch tokens
            patch_tokens = hidden_states[:, 1:, :]
            pooled_output = self.layernorm(patch_tokens.mean(1))
        else:
            # Pool by simply taking the final hidden state of the [CLS] token
            pooled_output = hidden_states[:, 0]

        return pooled_output


@add_start_docstrings(
    """
    Data2VecVision Model transformer with an image classification head on top (a linear layer on top of the average of
    the final hidden states of the patch tokens) e.g. for ImageNet.
    """,
    DATA2VEC_VISION_START_DOCSTRING,
)
# Copied from transformers.models.beit.modeling_beit.BeitForImageClassification with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,beit->data2vec_vision
class Data2VecVisionForImageClassification(Data2VecVisionPreTrainedModel):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=True)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.data2vec_vision(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
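

# Hedged usage sketch (comment only; the checkpoint and expected label match the docstring
# constants _IMAGE_CLASS_CHECKPOINT and _IMAGE_CLASS_EXPECTED_OUTPUT above, everything else is
# illustrative):
#
#     from transformers import AutoImageProcessor, Data2VecVisionForImageClassification
#
#     processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base-ft1k")
#     model = Data2VecVisionForImageClassification.from_pretrained("facebook/data2vec-vision-base-ft1k")
#     inputs = processor(images=image, return_tensors="pt")    # `image` is a PIL.Image
#     logits = model(**inputs).logits                           # (1, num_labels) for this ImageNet-1k checkpoint
#     model.config.id2label[logits.argmax(-1).item()]           # e.g. "remote control, remote"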


# Copied from transformers.models.beit.modeling_beit.BeitConvModule with Beit->Data2VecVision
class Data2VecVisionConvModule(nn.Module):
    """
    A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
    layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        padding: Union[int, Tuple[int, int], str] = 0,
        bias: bool = False,
        dilation: Union[int, Tuple[int, int]] = 1,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
            dilation=dilation,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        output = self.conv(input)
        output = self.bn(output)
        output = self.activation(output)

        return output


# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingBlock with Beit->Data2VecVision
class Data2VecVisionPyramidPoolingBlock(nn.Module):
    def __init__(self, pool_scale: int, in_channels: int, channels: int) -> None:
        super().__init__()
        self.layers = [
            nn.AdaptiveAvgPool2d(pool_scale),
            Data2VecVisionConvModule(in_channels, channels, kernel_size=1),
        ]
        for i, layer in enumerate(self.layers):
            self.add_module(str(i), layer)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        hidden_state = input
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state


# Copied from transformers.models.beit.modeling_beit.BeitPyramidPoolingModule with Beit->Data2VecVision
class Data2VecVisionPyramidPoolingModule(nn.Module):
    """
    Pyramid Pooling Module (PPM) used in PSPNet.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
        in_channels (int): Input channels.
        channels (int): Channels after modules, before conv_seg.
        align_corners (bool): align_corners argument of F.interpolate.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, pool_scales: Tuple[int, ...], in_channels: int, channels: int, align_corners: bool) -> None:
        super().__init__()
        self.pool_scales = pool_scales
        self.align_corners = align_corners
        self.in_channels = in_channels
        self.channels = channels
        self.blocks = []
        for i, pool_scale in enumerate(pool_scales):
            block = Data2VecVisionPyramidPoolingBlock(
                pool_scale=pool_scale, in_channels=in_channels, channels=channels
            )
            self.blocks.append(block)
            self.add_module(str(i), block)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        ppm_outs = []
        for ppm in self.blocks:
            ppm_out = ppm(x)
            upsampled_ppm_out = nn.functional.interpolate(
                ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
            )
            ppm_outs.append(upsampled_ppm_out)
        return ppm_outs


# Copied from transformers.models.beit.modeling_beit.BeitUperHead with Beit->Data2VecVision
class Data2VecVisionUperHead(nn.Module):
    """
    Unified Perceptual Parsing for Scene Understanding. This head is the implementation of
    [UPerNet](https://arxiv.org/abs/1807.10221).

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__()

        self.pool_scales = config.pool_scales  # e.g. (1, 2, 3, 6)
        self.in_channels = [config.hidden_size] * 4  # e.g. [768, 768, 768, 768]
        self.channels = config.hidden_size
        self.align_corners = False
        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

        # PSP Module
        self.psp_modules = Data2VecVisionPyramidPoolingModule(
            self.pool_scales,
            self.in_channels[-1],
            self.channels,
            align_corners=self.align_corners,
        )
        self.bottleneck = Data2VecVisionConvModule(
            self.in_channels[-1] + len(self.pool_scales) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )
        # FPN Module
        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()
        for in_channels in self.in_channels[:-1]:  # skip the top layer
            l_conv = Data2VecVisionConvModule(in_channels, self.channels, kernel_size=1)
            fpn_conv = Data2VecVisionConvModule(self.channels, self.channels, kernel_size=3, padding=1)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

        self.fpn_bottleneck = Data2VecVisionConvModule(
            len(self.in_channels) * self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
        )

    def psp_forward(self, inputs):
        x = inputs[-1]
        psp_outs = [x]
        psp_outs.extend(self.psp_modules(x))
        psp_outs = torch.cat(psp_outs, dim=1)
        output = self.bottleneck(psp_outs)

        return output

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        # build laterals
        laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]

        laterals.append(self.psp_forward(encoder_hidden_states))

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
                laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
            )

        # build outputs
        fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
        # append psp feature
        fpn_outs.append(laterals[-1])

        for i in range(used_backbone_levels - 1, 0, -1):
            fpn_outs[i] = nn.functional.interpolate(
                fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
            )
        fpn_outs = torch.cat(fpn_outs, dim=1)
        output = self.fpn_bottleneck(fpn_outs)
        output = self.classifier(output)

        return output


# Copied from transformers.models.beit.modeling_beit.BeitFCNHead with Beit->Data2VecVision
class Data2VecVisionFCNHead(nn.Module):
    """
    Fully Convolution Networks for Semantic Segmentation. This head is the implementation of
    [FCNNet](https://arxiv.org/abs/1411.4038).

    Args:
        config (Data2VecVisionConfig): Configuration.
        in_channels
        kernel_size (int): The kernel size for convs in the head. Default: 3.
        dilation (int): The dilation rate for convs in the head. Default: 1.

    Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
    """

    def __init__(
        self,
        config: Data2VecVisionConfig,
        in_index: int = 2,
        kernel_size: int = 3,
        dilation: Union[int, Tuple[int, int]] = 1,
    ) -> None:
        super().__init__()
        self.in_channels = config.hidden_size
        self.channels = config.auxiliary_channels
        self.num_convs = config.auxiliary_num_convs
        self.concat_input = config.auxiliary_concat_input
        self.in_index = in_index

        conv_padding = (kernel_size // 2) * dilation
        convs = []
        convs.append(
            Data2VecVisionConvModule(
                self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
            )
        )
        for i in range(self.num_convs - 1):
            convs.append(
                Data2VecVisionConvModule(
                    self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
                )
            )
        if self.num_convs == 0:
            self.convs = nn.Identity()
        else:
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = Data2VecVisionConvModule(
                self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
            )

        self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)

    def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        # just take the relevant feature maps
        hidden_states = encoder_hidden_states[self.in_index]
        output = self.convs(hidden_states)
        if self.concat_input:
            output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
        output = self.classifier(output)
        return output


@add_start_docstrings(
    """
    Data2VecVision Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
    """,
    DATA2VEC_VISION_START_DOCSTRING,
)
# Copied from transformers.models.beit.modeling_beit.BeitForSemanticSegmentation with BEIT->DATA2VEC_VISION,Beit->Data2VecVision,microsoft/beit-base-finetuned-ade-640-640->facebook/data2vec-vision-base,beit->data2vec_vision
class Data2VecVisionForSemanticSegmentation(Data2VecVisionPreTrainedModel):
    def __init__(self, config: Data2VecVisionConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.data2vec_vision = Data2VecVisionModel(config, add_pooling_layer=False)

        # FPNs
        if len(self.config.out_indices) != 4:
            raise ValueError(
                "Data2VecVisionForSemanticSegmentation requires config.out_indices to be a list of 4 integers, "
                "specifying which features to use from the backbone. One can use [3, 5, 7, 11] in case of "
                "a base-sized architecture."
            )
        self.fpn1 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
            nn.BatchNorm2d(config.hidden_size),
            nn.GELU(),
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        self.fpn2 = nn.Sequential(
            nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
        )
        self.fpn3 = nn.Identity()
        self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Semantic segmentation head(s)
        self.decode_head = Data2VecVisionUperHead(config)
        self.auxiliary_head = Data2VecVisionFCNHead(config) if config.use_auxiliary_head else None

        # Initialize weights and apply final processing
        self.post_init()

    def compute_loss(self, logits, auxiliary_logits, labels):
        # upsample logits to the images' original size
        upsampled_logits = nn.functional.interpolate(
            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
        )
        if auxiliary_logits is not None:
            upsampled_auxiliary_logits = nn.functional.interpolate(
                auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
        # compute weighted loss
        loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
        main_loss = loss_fct(upsampled_logits, labels)
        loss = main_loss
        if auxiliary_logits is not None:
            auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
            loss += self.config.auxiliary_loss_weight * auxiliary_loss

        return loss

    @add_start_docstrings_to_model_forward(DATA2VEC_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoImageProcessor, Data2VecVisionForSemanticSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("facebook/data2vec-vision-base")
        >>> model = Data2VecVisionForSemanticSegmentation.from_pretrained("facebook/data2vec-vision-base")

        >>> inputs = image_processor(images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        if labels is not None and self.config.num_labels == 1:
            raise ValueError("The number of labels should be greater than one")

        outputs = self.data2vec_vision(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,  # we need the intermediate hidden states
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        # only keep certain features, and reshape
        # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
        features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
        batch_size = pixel_values.shape[0]
        patch_resolution = self.config.image_size // self.config.patch_size
        features = [
            x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
        ]

        # apply FPNs
        ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
        for i in range(len(features)):
            features[i] = ops[i](features[i])

        logits = self.decode_head(features)

        auxiliary_logits = None
        if self.auxiliary_head is not None:
            auxiliary_logits = self.auxiliary_head(features)

        loss = None
        if labels is not None:
            loss = self.compute_loss(logits, auxiliary_logits, labels)

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions,
        )