modeling_van.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. # coding=utf-8
  2. # Copyright 2022 BNRist (Tsinghua University), TKLNDST (Nankai University) and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Visual Attention Network (VAN) model."""
  16. import math
  17. from collections import OrderedDict
  18. from typing import Optional, Tuple, Union
  19. import torch
  20. import torch.utils.checkpoint
  21. from torch import nn
  22. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  23. from ....activations import ACT2FN
  24. from ....modeling_outputs import (
  25. BaseModelOutputWithNoAttention,
  26. BaseModelOutputWithPoolingAndNoAttention,
  27. ImageClassifierOutputWithNoAttention,
  28. )
  29. from ....modeling_utils import PreTrainedModel
  30. from ....utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
  31. from .configuration_van import VanConfig
  32. logger = logging.get_logger(__name__)
  33. # General docstring
  34. _CONFIG_FOR_DOC = "VanConfig"
  35. # Base docstring
  36. _CHECKPOINT_FOR_DOC = "Visual-Attention-Network/van-base"
  37. _EXPECTED_OUTPUT_SHAPE = [1, 512, 7, 7]
  38. # Image classification docstring
  39. _IMAGE_CLASS_CHECKPOINT = "Visual-Attention-Network/van-base"
  40. _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
  41. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  42. """
  43. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  44. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  45. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  46. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  47. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  48. argument.
  49. """
  50. if drop_prob == 0.0 or not training:
  51. return input
  52. keep_prob = 1 - drop_prob
  53. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  54. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  55. random_tensor.floor_() # binarize
  56. output = input.div(keep_prob) * random_tensor
  57. return output
  58. class VanDropPath(nn.Module):
  59. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  60. def __init__(self, drop_prob: Optional[float] = None) -> None:
  61. super().__init__()
  62. self.drop_prob = drop_prob
  63. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  64. return drop_path(hidden_states, self.drop_prob, self.training)
  65. def extra_repr(self) -> str:
  66. return "p={}".format(self.drop_prob)
  67. class VanOverlappingPatchEmbedder(nn.Module):
  68. """
  69. Downsamples the input using a patchify operation with a `stride` of 4 by default making adjacent windows overlap by
  70. half of the area. From [PVTv2: Improved Baselines with Pyramid Vision
  71. Transformer](https://arxiv.org/abs/2106.13797).
  72. """
  73. def __init__(self, in_channels: int, hidden_size: int, patch_size: int = 7, stride: int = 4):
  74. super().__init__()
  75. self.convolution = nn.Conv2d(
  76. in_channels, hidden_size, kernel_size=patch_size, stride=stride, padding=patch_size // 2
  77. )
  78. self.normalization = nn.BatchNorm2d(hidden_size)
  79. def forward(self, input: torch.Tensor) -> torch.Tensor:
  80. hidden_state = self.convolution(input)
  81. hidden_state = self.normalization(hidden_state)
  82. return hidden_state
  83. class VanMlpLayer(nn.Module):
  84. """
  85. MLP with depth-wise convolution, from [PVTv2: Improved Baselines with Pyramid Vision
  86. Transformer](https://arxiv.org/abs/2106.13797).
  87. """
  88. def __init__(
  89. self,
  90. in_channels: int,
  91. hidden_size: int,
  92. out_channels: int,
  93. hidden_act: str = "gelu",
  94. dropout_rate: float = 0.5,
  95. ):
  96. super().__init__()
  97. self.in_dense = nn.Conv2d(in_channels, hidden_size, kernel_size=1)
  98. self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1, groups=hidden_size)
  99. self.activation = ACT2FN[hidden_act]
  100. self.dropout1 = nn.Dropout(dropout_rate)
  101. self.out_dense = nn.Conv2d(hidden_size, out_channels, kernel_size=1)
  102. self.dropout2 = nn.Dropout(dropout_rate)
  103. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  104. hidden_state = self.in_dense(hidden_state)
  105. hidden_state = self.depth_wise(hidden_state)
  106. hidden_state = self.activation(hidden_state)
  107. hidden_state = self.dropout1(hidden_state)
  108. hidden_state = self.out_dense(hidden_state)
  109. hidden_state = self.dropout2(hidden_state)
  110. return hidden_state
  111. class VanLargeKernelAttention(nn.Module):
  112. """
  113. Basic Large Kernel Attention (LKA).
  114. """
  115. def __init__(self, hidden_size: int):
  116. super().__init__()
  117. self.depth_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=5, padding=2, groups=hidden_size)
  118. self.depth_wise_dilated = nn.Conv2d(
  119. hidden_size, hidden_size, kernel_size=7, dilation=3, padding=9, groups=hidden_size
  120. )
  121. self.point_wise = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
  122. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  123. hidden_state = self.depth_wise(hidden_state)
  124. hidden_state = self.depth_wise_dilated(hidden_state)
  125. hidden_state = self.point_wise(hidden_state)
  126. return hidden_state
  127. class VanLargeKernelAttentionLayer(nn.Module):
  128. """
  129. Computes attention using Large Kernel Attention (LKA) and attends the input.
  130. """
  131. def __init__(self, hidden_size: int):
  132. super().__init__()
  133. self.attention = VanLargeKernelAttention(hidden_size)
  134. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  135. attention = self.attention(hidden_state)
  136. attended = hidden_state * attention
  137. return attended
  138. class VanSpatialAttentionLayer(nn.Module):
  139. """
  140. Van spatial attention layer composed by projection (via conv) -> act -> Large Kernel Attention (LKA) attention ->
  141. projection (via conv) + residual connection.
  142. """
  143. def __init__(self, hidden_size: int, hidden_act: str = "gelu"):
  144. super().__init__()
  145. self.pre_projection = nn.Sequential(
  146. OrderedDict(
  147. [
  148. ("conv", nn.Conv2d(hidden_size, hidden_size, kernel_size=1)),
  149. ("act", ACT2FN[hidden_act]),
  150. ]
  151. )
  152. )
  153. self.attention_layer = VanLargeKernelAttentionLayer(hidden_size)
  154. self.post_projection = nn.Conv2d(hidden_size, hidden_size, kernel_size=1)
  155. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  156. residual = hidden_state
  157. hidden_state = self.pre_projection(hidden_state)
  158. hidden_state = self.attention_layer(hidden_state)
  159. hidden_state = self.post_projection(hidden_state)
  160. hidden_state = hidden_state + residual
  161. return hidden_state
  162. class VanLayerScaling(nn.Module):
  163. """
  164. Scales the inputs by a learnable parameter initialized by `initial_value`.
  165. """
  166. def __init__(self, hidden_size: int, initial_value: float = 1e-2):
  167. super().__init__()
  168. self.weight = nn.Parameter(initial_value * torch.ones((hidden_size)), requires_grad=True)
  169. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  170. # unsqueezing for broadcasting
  171. hidden_state = self.weight.unsqueeze(-1).unsqueeze(-1) * hidden_state
  172. return hidden_state
  173. class VanLayer(nn.Module):
  174. """
  175. Van layer composed by normalization layers, large kernel attention (LKA) and a multi layer perceptron (MLP).
  176. """
  177. def __init__(
  178. self,
  179. config: VanConfig,
  180. hidden_size: int,
  181. mlp_ratio: int = 4,
  182. drop_path_rate: float = 0.5,
  183. ):
  184. super().__init__()
  185. self.drop_path = VanDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  186. self.pre_normomalization = nn.BatchNorm2d(hidden_size)
  187. self.attention = VanSpatialAttentionLayer(hidden_size, config.hidden_act)
  188. self.attention_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value)
  189. self.post_normalization = nn.BatchNorm2d(hidden_size)
  190. self.mlp = VanMlpLayer(
  191. hidden_size, hidden_size * mlp_ratio, hidden_size, config.hidden_act, config.dropout_rate
  192. )
  193. self.mlp_scaling = VanLayerScaling(hidden_size, config.layer_scale_init_value)
  194. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  195. residual = hidden_state
  196. # attention
  197. hidden_state = self.pre_normomalization(hidden_state)
  198. hidden_state = self.attention(hidden_state)
  199. hidden_state = self.attention_scaling(hidden_state)
  200. hidden_state = self.drop_path(hidden_state)
  201. # residual connection
  202. hidden_state = residual + hidden_state
  203. residual = hidden_state
  204. # mlp
  205. hidden_state = self.post_normalization(hidden_state)
  206. hidden_state = self.mlp(hidden_state)
  207. hidden_state = self.mlp_scaling(hidden_state)
  208. hidden_state = self.drop_path(hidden_state)
  209. # residual connection
  210. hidden_state = residual + hidden_state
  211. return hidden_state
  212. class VanStage(nn.Module):
  213. """
  214. VanStage, consisting of multiple layers.
  215. """
  216. def __init__(
  217. self,
  218. config: VanConfig,
  219. in_channels: int,
  220. hidden_size: int,
  221. patch_size: int,
  222. stride: int,
  223. depth: int,
  224. mlp_ratio: int = 4,
  225. drop_path_rate: float = 0.0,
  226. ):
  227. super().__init__()
  228. self.embeddings = VanOverlappingPatchEmbedder(in_channels, hidden_size, patch_size, stride)
  229. self.layers = nn.Sequential(
  230. *[
  231. VanLayer(
  232. config,
  233. hidden_size,
  234. mlp_ratio=mlp_ratio,
  235. drop_path_rate=drop_path_rate,
  236. )
  237. for _ in range(depth)
  238. ]
  239. )
  240. self.normalization = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  241. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  242. hidden_state = self.embeddings(hidden_state)
  243. hidden_state = self.layers(hidden_state)
  244. # rearrange b c h w -> b (h w) c
  245. batch_size, hidden_size, height, width = hidden_state.shape
  246. hidden_state = hidden_state.flatten(2).transpose(1, 2)
  247. hidden_state = self.normalization(hidden_state)
  248. # rearrange b (h w) c- > b c h w
  249. hidden_state = hidden_state.view(batch_size, height, width, hidden_size).permute(0, 3, 1, 2)
  250. return hidden_state
  251. class VanEncoder(nn.Module):
  252. """
  253. VanEncoder, consisting of multiple stages.
  254. """
  255. def __init__(self, config: VanConfig):
  256. super().__init__()
  257. self.stages = nn.ModuleList([])
  258. patch_sizes = config.patch_sizes
  259. strides = config.strides
  260. hidden_sizes = config.hidden_sizes
  261. depths = config.depths
  262. mlp_ratios = config.mlp_ratios
  263. drop_path_rates = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))]
  264. for num_stage, (patch_size, stride, hidden_size, depth, mlp_expantion, drop_path_rate) in enumerate(
  265. zip(patch_sizes, strides, hidden_sizes, depths, mlp_ratios, drop_path_rates)
  266. ):
  267. is_first_stage = num_stage == 0
  268. in_channels = hidden_sizes[num_stage - 1]
  269. if is_first_stage:
  270. in_channels = config.num_channels
  271. self.stages.append(
  272. VanStage(
  273. config,
  274. in_channels,
  275. hidden_size,
  276. patch_size=patch_size,
  277. stride=stride,
  278. depth=depth,
  279. mlp_ratio=mlp_expantion,
  280. drop_path_rate=drop_path_rate,
  281. )
  282. )
  283. def forward(
  284. self,
  285. hidden_state: torch.Tensor,
  286. output_hidden_states: Optional[bool] = False,
  287. return_dict: Optional[bool] = True,
  288. ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
  289. all_hidden_states = () if output_hidden_states else None
  290. for _, stage_module in enumerate(self.stages):
  291. hidden_state = stage_module(hidden_state)
  292. if output_hidden_states:
  293. all_hidden_states = all_hidden_states + (hidden_state,)
  294. if not return_dict:
  295. return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)
  296. return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
  297. class VanPreTrainedModel(PreTrainedModel):
  298. """
  299. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  300. models.
  301. """
  302. config_class = VanConfig
  303. base_model_prefix = "van"
  304. main_input_name = "pixel_values"
  305. supports_gradient_checkpointing = True
  306. def _init_weights(self, module):
  307. """Initialize the weights"""
  308. if isinstance(module, nn.Linear):
  309. nn.init.trunc_normal_(module.weight, std=self.config.initializer_range)
  310. if isinstance(module, nn.Linear) and module.bias is not None:
  311. nn.init.constant_(module.bias, 0)
  312. elif isinstance(module, nn.LayerNorm):
  313. nn.init.constant_(module.bias, 0)
  314. nn.init.constant_(module.weight, 1.0)
  315. elif isinstance(module, nn.Conv2d):
  316. fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
  317. fan_out //= module.groups
  318. module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
  319. if module.bias is not None:
  320. module.bias.data.zero_()
  321. VAN_START_DOCSTRING = r"""
  322. This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
  323. as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
  324. behavior.
  325. Parameters:
  326. config ([`VanConfig`]): Model configuration class with all the parameters of the model.
  327. Initializing with a config file does not load the weights associated with the model, only the
  328. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  329. """
  330. VAN_INPUTS_DOCSTRING = r"""
  331. Args:
  332. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
  333. Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
  334. [`ConvNextImageProcessor.__call__`] for details.
  335. output_hidden_states (`bool`, *optional*):
  336. Whether or not to return the hidden states of all stages. See `hidden_states` under returned tensors for
  337. more detail.
  338. return_dict (`bool`, *optional*):
  339. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  340. """
  341. @add_start_docstrings(
  342. "The bare VAN model outputting raw features without any specific head on top. Note, VAN does not have an embedding"
  343. " layer.",
  344. VAN_START_DOCSTRING,
  345. )
  346. class VanModel(VanPreTrainedModel):
  347. def __init__(self, config):
  348. super().__init__(config)
  349. self.config = config
  350. self.encoder = VanEncoder(config)
  351. # final layernorm layer
  352. self.layernorm = nn.LayerNorm(config.hidden_sizes[-1], eps=config.layer_norm_eps)
  353. # Initialize weights and apply final processing
  354. self.post_init()
  355. @add_start_docstrings_to_model_forward(VAN_INPUTS_DOCSTRING)
  356. @add_code_sample_docstrings(
  357. checkpoint=_CHECKPOINT_FOR_DOC,
  358. output_type=BaseModelOutputWithPoolingAndNoAttention,
  359. config_class=_CONFIG_FOR_DOC,
  360. modality="vision",
  361. expected_output=_EXPECTED_OUTPUT_SHAPE,
  362. )
  363. def forward(
  364. self,
  365. pixel_values: Optional[torch.FloatTensor],
  366. output_hidden_states: Optional[bool] = None,
  367. return_dict: Optional[bool] = None,
  368. ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
  369. output_hidden_states = (
  370. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  371. )
  372. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  373. encoder_outputs = self.encoder(
  374. pixel_values,
  375. output_hidden_states=output_hidden_states,
  376. return_dict=return_dict,
  377. )
  378. last_hidden_state = encoder_outputs[0]
  379. # global average pooling, n c w h -> n c
  380. pooled_output = last_hidden_state.mean(dim=[-2, -1])
  381. if not return_dict:
  382. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  383. return BaseModelOutputWithPoolingAndNoAttention(
  384. last_hidden_state=last_hidden_state,
  385. pooler_output=pooled_output,
  386. hidden_states=encoder_outputs.hidden_states,
  387. )
  388. @add_start_docstrings(
  389. """
  390. VAN Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
  391. ImageNet.
  392. """,
  393. VAN_START_DOCSTRING,
  394. )
  395. class VanForImageClassification(VanPreTrainedModel):
  396. def __init__(self, config):
  397. super().__init__(config)
  398. self.van = VanModel(config)
  399. # Classifier head
  400. self.classifier = (
  401. nn.Linear(config.hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
  402. )
  403. # Initialize weights and apply final processing
  404. self.post_init()
  405. @add_start_docstrings_to_model_forward(VAN_INPUTS_DOCSTRING)
  406. @add_code_sample_docstrings(
  407. checkpoint=_IMAGE_CLASS_CHECKPOINT,
  408. output_type=ImageClassifierOutputWithNoAttention,
  409. config_class=_CONFIG_FOR_DOC,
  410. expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
  411. )
  412. def forward(
  413. self,
  414. pixel_values: Optional[torch.FloatTensor] = None,
  415. labels: Optional[torch.LongTensor] = None,
  416. output_hidden_states: Optional[bool] = None,
  417. return_dict: Optional[bool] = None,
  418. ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
  419. r"""
  420. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  421. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  422. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  423. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  424. """
  425. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  426. outputs = self.van(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)
  427. pooled_output = outputs.pooler_output if return_dict else outputs[1]
  428. logits = self.classifier(pooled_output)
  429. loss = None
  430. if labels is not None:
  431. if self.config.problem_type is None:
  432. if self.config.num_labels == 1:
  433. self.config.problem_type = "regression"
  434. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  435. self.config.problem_type = "single_label_classification"
  436. else:
  437. self.config.problem_type = "multi_label_classification"
  438. if self.config.problem_type == "regression":
  439. loss_fct = MSELoss()
  440. if self.config.num_labels == 1:
  441. loss = loss_fct(logits.squeeze(), labels.squeeze())
  442. else:
  443. loss = loss_fct(logits, labels)
  444. elif self.config.problem_type == "single_label_classification":
  445. loss_fct = CrossEntropyLoss()
  446. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  447. elif self.config.problem_type == "multi_label_classification":
  448. loss_fct = BCEWithLogitsLoss()
  449. loss = loss_fct(logits, labels)
  450. if not return_dict:
  451. output = (logits,) + outputs[2:]
  452. return ((loss,) + output) if loss is not None else output
  453. return ImageClassifierOutputWithNoAttention(loss=loss, logits=logits, hidden_states=outputs.hidden_states)