modeling_vit.py
# coding=utf-8
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch ViT model."""

import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPooling,
    ImageClassifierOutput,
    MaskedImageModelingOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
    torch_int,
)
from .configuration_vit import ViTConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "ViTConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224-in21k"
_EXPECTED_OUTPUT_SHAPE = [1, 197, 768]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "google/vit-base-patch16-224"
_IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"


class ViTEmbeddings(nn.Module):
    """
    Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
    """

    def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
        super().__init__()

        self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
        self.patch_embeddings = ViTPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.patch_size = config.patch_size
        self.config = config

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method interpolates the pre-trained position encodings so that the model can be used on
        higher-resolution images. It is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        num_patches = embeddings.shape[1] - 1
        num_positions = self.position_embeddings.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
            return self.position_embeddings

        class_pos_embed = self.position_embeddings[:, :1]
        patch_pos_embed = self.position_embeddings[:, 1:]

        dim = embeddings.shape[-1]

        new_height = height // self.patch_size
        new_width = width // self.patch_size

        sqrt_num_positions = torch_int(num_positions**0.5)
        patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            size=(new_height, new_width),
            mode="bicubic",
            align_corners=False,
        )

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add the [CLS] token to the embedded patch tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings
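

# The following helper is an illustrative sketch added for exposition and is not part of the original
# Transformers module. It shows how `interpolate_pos_encoding` lets a model trained on 224x224 inputs
# accept a larger image (here 384x384) by resizing the 14x14 grid of position embeddings to 24x24.
def _example_interpolate_pos_encoding() -> None:
    """Illustrative only: run a randomly initialized `ViTEmbeddings` on a 384x384 input."""
    config = ViTConfig()  # ViT-base defaults: image_size=224, patch_size=16, hidden_size=768
    embeddings_module = ViTEmbeddings(config)
    pixel_values = torch.randn(1, 3, 384, 384)  # larger than the 224x224 pre-training resolution
    embeddings = embeddings_module(pixel_values, interpolate_pos_encoding=True)
    # 384 // 16 = 24 patches per side -> 24 * 24 patches + 1 [CLS] token = 577 positions
    assert embeddings.shape == (1, 577, 768)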


class ViTPatchEmbeddings(nn.Module):
    """
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        batch_size, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {num_channels}."
            )
        if not interpolate_pos_encoding:
            if height != self.image_size[0] or width != self.image_size[1]:
                raise ValueError(
                    f"Input image size ({height}*{width}) doesn't match model"
                    f" ({self.image_size[0]}*{self.image_size[1]})."
                )
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings
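

# Illustrative sketch (not part of the original file): the shape bookkeeping performed by
# `ViTPatchEmbeddings` for the ViT-base defaults, using a plain Conv2d with stride == kernel_size
# so that each 16x16 patch is projected independently to a 768-dimensional embedding.
def _example_patch_embedding_shapes() -> None:
    """Illustrative only: 224x224 RGB image -> 14x14 grid of patches -> 196 patch embeddings."""
    projection = nn.Conv2d(3, 768, kernel_size=16, stride=16)
    pixel_values = torch.randn(1, 3, 224, 224)
    features = projection(pixel_values)  # (1, 768, 14, 14)
    embeddings = features.flatten(2).transpose(1, 2)  # (1, 196, 768), matching the forward above
    assert embeddings.shape == (1, 196, 768)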


class ViTSelfAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
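

# Illustrative sketch (not part of the original file): the reshapes behind `transpose_for_scores`
# and the attention-score computation for ViT-base (12 heads, head size 768 / 12 = 64).
def _example_attention_head_reshape() -> None:
    """Illustrative only: shape bookkeeping of multi-head attention on 197 tokens."""
    batch_size, seq_len, hidden_size, num_heads = 2, 197, 768, 12
    head_size = hidden_size // num_heads  # 64
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    # (batch, seq, hidden) -> (batch, heads, seq, head_size)
    per_head = hidden_states.view(batch_size, seq_len, num_heads, head_size).permute(0, 2, 1, 3)
    assert per_head.shape == (batch_size, num_heads, seq_len, head_size)
    # raw attention scores have shape (batch, heads, seq, seq), scaled by sqrt(head_size)
    scores = torch.matmul(per_head, per_head.transpose(-1, -2)) / math.sqrt(head_size)
    assert scores.shape == (batch_size, num_heads, seq_len, seq_len)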


class ViTSdpaSelfAttention(ViTSelfAttention):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        if output_attentions or head_mask is not None:
            logger.warning_once(
                "`ViTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                head_mask=head_mask,
                output_attentions=output_attentions,
            )

        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            self.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=None,
        )

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, None


class ViTSelfOutput(nn.Module):
    """
    The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


class ViTAttention(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.attention = ViTSelfAttention(config)
        self.output = ViTSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


class ViTSdpaAttention(ViTAttention):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)
        self.attention = ViTSdpaSelfAttention(config)


class ViTIntermediate(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


class ViTOutput(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


VIT_ATTENTION_CLASSES = {
    "eager": ViTAttention,
    "sdpa": ViTSdpaAttention,
}
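

# Illustrative sketch (not part of the original file): both entries of `VIT_ATTENTION_CLASSES`
# expose the same interface and output shape; "sdpa" routes the computation through
# `torch.nn.functional.scaled_dot_product_attention`. When loading a checkpoint, the implementation
# can be requested with the `attn_implementation` argument of `from_pretrained` (e.g. "eager" or "sdpa").
def _example_attention_implementations() -> None:
    """Illustrative only: eager and SDPA attention modules produce identically shaped outputs."""
    config = ViTConfig()
    hidden_states = torch.randn(1, 197, config.hidden_size)
    eager_output = VIT_ATTENTION_CLASSES["eager"](config)(hidden_states)[0]
    sdpa_output = VIT_ATTENTION_CLASSES["sdpa"](config)(hidden_states)[0]
    assert eager_output.shape == sdpa_output.shape == (1, 197, config.hidden_size)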


class ViTLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = VIT_ATTENTION_CLASSES[config._attn_implementation](config)
        self.intermediate = ViTIntermediate(config)
        self.output = ViTOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in ViT, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in ViT, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs
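

# Illustrative sketch (not part of the original file): the pre-norm residual flow of `ViTLayer`,
#     h = x + Attention(LayerNorm_before(x))
#     y = h + ViTOutput(ViTIntermediate(LayerNorm_after(h)))
# shown on a randomly initialized layer (assumes the config's default attention implementation
# resolves to one of the keys in `VIT_ATTENTION_CLASSES`).
def _example_vit_layer_dataflow() -> None:
    """Illustrative only: a ViT block maps (batch, seq, hidden) to a tensor of the same shape."""
    layer = ViTLayer(ViTConfig())
    hidden_states = torch.randn(2, 197, 768)
    layer_output = layer(hidden_states)[0]
    assert layer_output.shape == hidden_states.shape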


class ViTEncoder(nn.Module):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ViTPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = ViTConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ViTEmbeddings", "ViTLayer"]
    _supports_sdpa = True

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ViTEmbeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)
            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)


VIT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`ViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See [`ViTImageProcessor.__call__`]
            for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.

        interpolate_pos_encoding (`bool`, *optional*):
            Whether to interpolate the pre-trained position encodings.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare ViT Model transformer outputting raw hidden-states without any specific head on top.",
    VIT_START_DOCSTRING,
)
class ViTModel(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
        super().__init__(config)
        self.config = config

        self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
        self.encoder = ViTEncoder(config)

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.pooler = ViTPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> ViTPatchEmbeddings:
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
        base class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPooling,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Prepare head mask if needed
        # 1.0 in head_mask indicates we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
        expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
        if pixel_values.dtype != expected_dtype:
            pixel_values = pixel_values.to(expected_dtype)

        embedding_output = self.embeddings(
            pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
        )

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(sequence_output)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
            return head_outputs + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class ViTPooler(nn.Module):
    def __init__(self, config: ViTConfig):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
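

# Illustrative sketch (not part of the original file): end-to-end output shapes of the bare `ViTModel`
# with its default pooling layer, on a randomly initialized ViT-base configuration (no checkpoint download).
def _example_vit_model_outputs() -> None:
    """Illustrative only: 196 patches + 1 [CLS] token -> (1, 197, 768); pooled [CLS] -> (1, 768)."""
    model = ViTModel(ViTConfig())
    model.eval()
    pixel_values = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        outputs = model(pixel_values)
    assert outputs.last_hidden_state.shape == (1, 197, 768)  # matches _EXPECTED_OUTPUT_SHAPE
    assert outputs.pooler_output.shape == (1, 768)  # tanh-activated projection of the [CLS] state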


@add_start_docstrings(
    """ViT Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://arxiv.org/abs/2111.09886).

    <Tip>

    Note that we provide a script to pre-train this model on custom data in our [examples
    directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).

    </Tip>
    """,
    VIT_START_DOCSTRING,
)
class ViTForMaskedImageModeling(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)

        self.vit = ViTModel(config, add_pooling_layer=False, use_mask_token=True)

        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=config.hidden_size,
                out_channels=config.encoder_stride**2 * config.num_channels,
                kernel_size=1,
            ),
            nn.PixelShuffle(config.encoder_stride),
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, MaskedImageModelingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling
        >>> import torch
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
        >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
        >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
        >>> # create random boolean mask of shape (batch_size, num_patches)
        >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
        >>> list(reconstructed_pixel_values.shape)
        [1, 3, 224, 224]
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride):
            raise ValueError(
                "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that "
                "the reconstructed image has the same dimensions as the input. "
                f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}."
            )

        outputs = self.vit(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Reshape to (batch_size, num_channels, height, width)
        sequence_output = sequence_output[:, 1:]
        batch_size, sequence_length, num_channels = sequence_output.shape
        height = width = math.floor(sequence_length**0.5)
        sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)

        # Reconstruct pixel values
        reconstructed_pixel_values = self.decoder(sequence_output)

        masked_im_loss = None
        if bool_masked_pos is not None:
            size = self.config.image_size // self.config.patch_size
            bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
            mask = (
                bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
                .repeat_interleave(self.config.patch_size, 2)
                .unsqueeze(1)
                .contiguous()
            )
            reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
            masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

        if not return_dict:
            output = (reconstructed_pixel_values,) + outputs[1:]
            return ((masked_im_loss,) + output) if masked_im_loss is not None else output

        return MaskedImageModelingOutput(
            loss=masked_im_loss,
            reconstruction=reconstructed_pixel_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
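

# Illustrative sketch (not part of the original file): how the patch-level `bool_masked_pos` used by
# the masked-image-modeling loss above is expanded to a pixel-level mask for ViT-base
# (a 14x14 patch grid, each patch covering 16x16 pixels).
def _example_mask_upsampling() -> None:
    """Illustrative only: a (1, 196) patch mask becomes a (1, 1, 224, 224) pixel mask."""
    patch_size, grid_size = 16, 14  # 224 // 16 = 14
    bool_masked_pos = torch.randint(low=0, high=2, size=(1, grid_size * grid_size)).bool()
    mask = (
        bool_masked_pos.reshape(-1, grid_size, grid_size)
        .repeat_interleave(patch_size, 1)
        .repeat_interleave(patch_size, 2)
        .unsqueeze(1)
    )
    assert mask.shape == (1, 1, 224, 224)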


@add_start_docstrings(
    """
    ViT Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
    the [CLS] token) e.g. for ImageNet.

    <Tip>

    Note that it's possible to fine-tune ViT on higher resolution images than the ones it has been trained on, by
    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
    position embeddings to the higher resolution.

    </Tip>
    """,
    VIT_START_DOCSTRING,
)
class ViTForImageClassification(ViTPreTrainedModel):
    def __init__(self, config: ViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.vit = ViTModel(config, add_pooling_layer=False)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vit(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.classifier(sequence_output[:, 0, :])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
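

# Illustrative usage sketch (not part of the original file), mirroring the documented checkpoint
# `google/vit-base-patch16-224`; it requires network access to download the weights and the
# `PIL`/`requests` packages used by the docstring examples above.
def _example_image_classification() -> None:
    """Illustrative only: classify a COCO image with a pretrained ViT classifier."""
    import requests
    from PIL import Image

    from transformers import AutoImageProcessor

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)

    image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")
    model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_label = model.config.id2label[logits.argmax(-1).item()]
    print(predicted_label)  # for this image the documented expected output is "Egyptian cat"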