# coding=utf-8
# Copyright 2023 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
"""PyTorch MobileViTV2 model."""

from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import (
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
    SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_mobilevitv2 import MobileViTV2Config


logger = logging.get_logger(__name__)


# General docstring
_CONFIG_FOR_DOC = "MobileViTV2Config"

# Base docstring
_CHECKPOINT_FOR_DOC = "apple/mobilevitv2-1.0-imagenet1k-256"
_EXPECTED_OUTPUT_SHAPE = [1, 512, 8, 8]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "apple/mobilevitv2-1.0-imagenet1k-256"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"


# Copied from transformers.models.mobilevit.modeling_mobilevit.make_divisible
def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
    """
    Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
    original TensorFlow repo. It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_value < 0.9 * value:
        new_value += divisor
    return int(new_value)
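
# Illustrative values (a sketch, not part of the upstream file): the formula rounds half-up to the
# nearest multiple of `divisor`, and the 10% guard bumps the result back up by one `divisor` when
# rounding down would lose more than 10% of the requested channel count.
#
#     make_divisible(100, divisor=8)   # -> 104 (100 + 4 rounds up to the next multiple of 8)
#     make_divisible(39, divisor=16)   # -> 48  (32 would be >10% below 39, so one divisor is added)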


def clip(value: float, min_val: float = float("-inf"), max_val: float = float("inf")) -> float:
    return max(min_val, min(max_val, value))


# Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTConvLayer with MobileViT->MobileViTV2
class MobileViTV2ConvLayer(nn.Module):
    def __init__(
        self,
        config: MobileViTV2Config,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        bias: bool = False,
        dilation: int = 1,
        use_normalization: bool = True,
        use_activation: Union[bool, str] = True,
    ) -> None:
        super().__init__()
        padding = int((kernel_size - 1) / 2) * dilation

        if in_channels % groups != 0:
            raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
        if out_channels % groups != 0:
            raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")

        self.convolution = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode="zeros",
        )

        if use_normalization:
            self.normalization = nn.BatchNorm2d(
                num_features=out_channels,
                eps=1e-5,
                momentum=0.1,
                affine=True,
                track_running_stats=True,
            )
        else:
            self.normalization = None

        if use_activation:
            if isinstance(use_activation, str):
                self.activation = ACT2FN[use_activation]
            elif isinstance(config.hidden_act, str):
                self.activation = ACT2FN[config.hidden_act]
            else:
                self.activation = config.hidden_act
        else:
            self.activation = None

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        features = self.convolution(features)
        if self.normalization is not None:
            features = self.normalization(features)
        if self.activation is not None:
            features = self.activation(features)
        return features


# Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTInvertedResidual with MobileViT->MobileViTV2
class MobileViTV2InvertedResidual(nn.Module):
    """
    Inverted residual block (MobileNetv2): https://arxiv.org/abs/1801.04381
    """

    def __init__(
        self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int, dilation: int = 1
    ) -> None:
        super().__init__()
        expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)

        if stride not in [1, 2]:
            raise ValueError(f"Invalid stride {stride}.")

        self.use_residual = (stride == 1) and (in_channels == out_channels)

        self.expand_1x1 = MobileViTV2ConvLayer(
            config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1
        )

        self.conv_3x3 = MobileViTV2ConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=expanded_channels,
            kernel_size=3,
            stride=stride,
            groups=expanded_channels,
            dilation=dilation,
        )

        self.reduce_1x1 = MobileViTV2ConvLayer(
            config,
            in_channels=expanded_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation=False,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        residual = features

        features = self.expand_1x1(features)
        features = self.conv_3x3(features)
        features = self.reduce_1x1(features)

        return residual + features if self.use_residual else features


# Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTMobileNetLayer with MobileViT->MobileViTV2
class MobileViTV2MobileNetLayer(nn.Module):
    def __init__(
        self, config: MobileViTV2Config, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
    ) -> None:
        super().__init__()

        self.layer = nn.ModuleList()
        for i in range(num_stages):
            layer = MobileViTV2InvertedResidual(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride if i == 0 else 1,
            )
            self.layer.append(layer)
            in_channels = out_channels

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        for layer_module in self.layer:
            features = layer_module(features)
        return features


class MobileViTV2LinearSelfAttention(nn.Module):
    """
    This layer applies a self-attention with linear complexity, as described in the MobileViTV2 paper:
    https://arxiv.org/abs/2206.02680

    Args:
        config (`MobileViTV2Config`):
            Model configuration object
        embed_dim (`int`):
            `input_channels` from an expected input of size :math:`(batch_size, input_channels, height, width)`
    """

    def __init__(self, config: MobileViTV2Config, embed_dim: int) -> None:
        super().__init__()

        self.qkv_proj = MobileViTV2ConvLayer(
            config=config,
            in_channels=embed_dim,
            out_channels=1 + (2 * embed_dim),
            bias=True,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )
        self.attn_dropout = nn.Dropout(p=config.attn_dropout)
        self.out_proj = MobileViTV2ConvLayer(
            config=config,
            in_channels=embed_dim,
            out_channels=embed_dim,
            bias=True,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )
        self.embed_dim = embed_dim

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # (batch_size, embed_dim, num_pixels_in_patch, num_patches) --> (batch_size, 1+2*embed_dim, num_pixels_in_patch, num_patches)
        qkv = self.qkv_proj(hidden_states)

        # Project hidden_states into query, key and value
        # Query --> [batch_size, 1, num_pixels_in_patch, num_patches]
        # value, key --> [batch_size, embed_dim, num_pixels_in_patch, num_patches]
        query, key, value = torch.split(qkv, split_size_or_sections=[1, self.embed_dim, self.embed_dim], dim=1)

        # apply softmax along num_patches dimension
        context_scores = torch.nn.functional.softmax(query, dim=-1)
        context_scores = self.attn_dropout(context_scores)

        # Compute context vector
        # [batch_size, embed_dim, num_pixels_in_patch, num_patches] x [batch_size, 1, num_pixels_in_patch, num_patches] -> [batch_size, embed_dim, num_pixels_in_patch, num_patches]
        context_vector = key * context_scores
        # [batch_size, embed_dim, num_pixels_in_patch, num_patches] --> [batch_size, embed_dim, num_pixels_in_patch, 1]
        context_vector = torch.sum(context_vector, dim=-1, keepdim=True)

        # combine context vector with values
        # [batch_size, embed_dim, num_pixels_in_patch, num_patches] * [batch_size, embed_dim, num_pixels_in_patch, 1] --> [batch_size, embed_dim, num_pixels_in_patch, num_patches]
        out = torch.nn.functional.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        return out
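
# A minimal sketch of the separable (linear-complexity) attention above using plain tensors; the
# sizes below are illustrative only and are not taken from any checkpoint:
#
#     batch_size, embed_dim, pixels, patches = 2, 64, 4, 256
#     qkv = torch.randn(batch_size, 1 + 2 * embed_dim, pixels, patches)
#     query, key, value = torch.split(qkv, [1, embed_dim, embed_dim], dim=1)
#     scores = query.softmax(dim=-1)                      # one score per patch: O(patches), not O(patches^2)
#     context = (key * scores).sum(dim=-1, keepdim=True)  # a single context vector per pixel position
#     out = torch.relu(value) * context                   # broadcast the context back over all patches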


class MobileViTV2FFN(nn.Module):
    def __init__(
        self,
        config: MobileViTV2Config,
        embed_dim: int,
        ffn_latent_dim: int,
        ffn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.conv1 = MobileViTV2ConvLayer(
            config=config,
            in_channels=embed_dim,
            out_channels=ffn_latent_dim,
            kernel_size=1,
            stride=1,
            bias=True,
            use_normalization=False,
            use_activation=True,
        )
        self.dropout1 = nn.Dropout(ffn_dropout)

        self.conv2 = MobileViTV2ConvLayer(
            config=config,
            in_channels=ffn_latent_dim,
            out_channels=embed_dim,
            kernel_size=1,
            stride=1,
            bias=True,
            use_normalization=False,
            use_activation=False,
        )
        self.dropout2 = nn.Dropout(ffn_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.dropout1(hidden_states)
        hidden_states = self.conv2(hidden_states)
        hidden_states = self.dropout2(hidden_states)
        return hidden_states


class MobileViTV2TransformerLayer(nn.Module):
    def __init__(
        self,
        config: MobileViTV2Config,
        embed_dim: int,
        ffn_latent_dim: int,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.layernorm_before = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps)
        self.attention = MobileViTV2LinearSelfAttention(config, embed_dim)
        self.dropout1 = nn.Dropout(p=dropout)
        self.layernorm_after = nn.GroupNorm(num_groups=1, num_channels=embed_dim, eps=config.layer_norm_eps)
        self.ffn = MobileViTV2FFN(config, embed_dim, ffn_latent_dim, config.ffn_dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        layernorm_1_out = self.layernorm_before(hidden_states)
        attention_output = self.attention(layernorm_1_out)
        hidden_states = attention_output + hidden_states

        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.ffn(layer_output)

        layer_output = layer_output + hidden_states
        return layer_output


class MobileViTV2Transformer(nn.Module):
    def __init__(self, config: MobileViTV2Config, n_layers: int, d_model: int) -> None:
        super().__init__()

        ffn_multiplier = config.ffn_multiplier

        ffn_dims = [ffn_multiplier * d_model] * n_layers

        # ensure that dims are multiple of 16
        ffn_dims = [int((d // 16) * 16) for d in ffn_dims]

        self.layer = nn.ModuleList()
        for block_idx in range(n_layers):
            transformer_layer = MobileViTV2TransformerLayer(
                config, embed_dim=d_model, ffn_latent_dim=ffn_dims[block_idx]
            )
            self.layer.append(transformer_layer)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states)
        return hidden_states


class MobileViTV2Layer(nn.Module):
    """
    MobileViTV2 layer: https://arxiv.org/abs/2206.02680
    """

    def __init__(
        self,
        config: MobileViTV2Config,
        in_channels: int,
        out_channels: int,
        attn_unit_dim: int,
        n_attn_blocks: int = 2,
        dilation: int = 1,
        stride: int = 2,
    ) -> None:
        super().__init__()
        self.patch_width = config.patch_size
        self.patch_height = config.patch_size

        cnn_out_dim = attn_unit_dim

        if stride == 2:
            self.downsampling_layer = MobileViTV2InvertedResidual(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride if dilation == 1 else 1,
                dilation=dilation // 2 if dilation > 1 else 1,
            )
            in_channels = out_channels
        else:
            self.downsampling_layer = None

        # Local representations
        self.conv_kxk = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=config.conv_kernel_size,
            groups=in_channels,
        )
        self.conv_1x1 = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=cnn_out_dim,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )

        # Global representations
        self.transformer = MobileViTV2Transformer(config, d_model=attn_unit_dim, n_layers=n_attn_blocks)

        # self.layernorm = MobileViTV2LayerNorm2D(attn_unit_dim, eps=config.layer_norm_eps)
        self.layernorm = nn.GroupNorm(num_groups=1, num_channels=attn_unit_dim, eps=config.layer_norm_eps)

        # Fusion
        self.conv_projection = MobileViTV2ConvLayer(
            config,
            in_channels=cnn_out_dim,
            out_channels=in_channels,
            kernel_size=1,
            use_normalization=True,
            use_activation=False,
        )

    def unfolding(self, feature_map: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        batch_size, in_channels, img_height, img_width = feature_map.shape
        patches = nn.functional.unfold(
            feature_map,
            kernel_size=(self.patch_height, self.patch_width),
            stride=(self.patch_height, self.patch_width),
        )
        patches = patches.reshape(batch_size, in_channels, self.patch_height * self.patch_width, -1)

        return patches, (img_height, img_width)

    def folding(self, patches: torch.Tensor, output_size: Tuple[int, int]) -> torch.Tensor:
        batch_size, in_dim, patch_size, n_patches = patches.shape
        patches = patches.reshape(batch_size, in_dim * patch_size, n_patches)

        feature_map = nn.functional.fold(
            patches,
            output_size=output_size,
            kernel_size=(self.patch_height, self.patch_width),
            stride=(self.patch_height, self.patch_width),
        )

        return feature_map

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # reduce spatial dimensions if needed
        if self.downsampling_layer:
            features = self.downsampling_layer(features)

        # local representation
        features = self.conv_kxk(features)
        features = self.conv_1x1(features)

        # convert feature map to patches
        patches, output_size = self.unfolding(features)

        # learn global representations
        patches = self.transformer(patches)
        patches = self.layernorm(patches)

        # convert patches back to feature maps
        # [batch_size, patch_height, patch_width, input_dim] --> [batch_size, input_dim, patch_height, patch_width]
        features = self.folding(patches, output_size)

        features = self.conv_projection(features)
        return features
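
# Shape round-trip of the non-overlapping unfold/fold pair used above (an illustrative sketch; the
# tensor sizes are made up and assume the feature map is an exact multiple of the patch size):
#
#     x = torch.randn(1, 32, 16, 16)                              # (batch, channels, height, width)
#     patches = nn.functional.unfold(x, kernel_size=2, stride=2)  # (1, 32 * 2 * 2, 64)
#     patches = patches.reshape(1, 32, 4, 64)                     # (batch, channels, pixels_per_patch, num_patches)
#     restored = nn.functional.fold(
#         patches.reshape(1, 32 * 4, 64), output_size=(16, 16), kernel_size=2, stride=2
#     )
#     assert torch.allclose(x, restored)  # with non-overlapping patches, fold inverts unfold exactly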


class MobileViTV2Encoder(nn.Module):
    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__()
        self.config = config

        self.layer = nn.ModuleList()
        self.gradient_checkpointing = False

        # segmentation architectures like DeepLab and PSPNet modify the strides
        # of the classification backbones
        dilate_layer_4 = dilate_layer_5 = False
        if config.output_stride == 8:
            dilate_layer_4 = True
            dilate_layer_5 = True
        elif config.output_stride == 16:
            dilate_layer_5 = True

        dilation = 1

        layer_0_dim = make_divisible(
            clip(value=32 * config.width_multiplier, min_val=16, max_val=64), divisor=8, min_value=16
        )

        layer_1_dim = make_divisible(64 * config.width_multiplier, divisor=16)
        layer_2_dim = make_divisible(128 * config.width_multiplier, divisor=8)
        layer_3_dim = make_divisible(256 * config.width_multiplier, divisor=8)
        layer_4_dim = make_divisible(384 * config.width_multiplier, divisor=8)
        layer_5_dim = make_divisible(512 * config.width_multiplier, divisor=8)

        layer_1 = MobileViTV2MobileNetLayer(
            config,
            in_channels=layer_0_dim,
            out_channels=layer_1_dim,
            stride=1,
            num_stages=1,
        )
        self.layer.append(layer_1)

        layer_2 = MobileViTV2MobileNetLayer(
            config,
            in_channels=layer_1_dim,
            out_channels=layer_2_dim,
            stride=2,
            num_stages=2,
        )
        self.layer.append(layer_2)

        layer_3 = MobileViTV2Layer(
            config,
            in_channels=layer_2_dim,
            out_channels=layer_3_dim,
            attn_unit_dim=make_divisible(config.base_attn_unit_dims[0] * config.width_multiplier, divisor=8),
            n_attn_blocks=config.n_attn_blocks[0],
        )
        self.layer.append(layer_3)

        if dilate_layer_4:
            dilation *= 2

        layer_4 = MobileViTV2Layer(
            config,
            in_channels=layer_3_dim,
            out_channels=layer_4_dim,
            attn_unit_dim=make_divisible(config.base_attn_unit_dims[1] * config.width_multiplier, divisor=8),
            n_attn_blocks=config.n_attn_blocks[1],
            dilation=dilation,
        )
        self.layer.append(layer_4)

        if dilate_layer_5:
            dilation *= 2

        layer_5 = MobileViTV2Layer(
            config,
            in_channels=layer_4_dim,
            out_channels=layer_5_dim,
            attn_unit_dim=make_divisible(config.base_attn_unit_dims[2] * config.width_multiplier, divisor=8),
            n_attn_blocks=config.n_attn_blocks[2],
            dilation=dilation,
        )
        self.layer.append(layer_5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutputWithNoAttention]:
        all_hidden_states = () if output_hidden_states else None

        for i, layer_module in enumerate(self.layer):
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                )
            else:
                hidden_states = layer_module(hidden_states)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)


# Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTPreTrainedModel with MobileViT->MobileViTV2,mobilevit->mobilevitv2
class MobileViTV2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MobileViTV2Config
    base_model_prefix = "mobilevitv2"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MobileViTV2Layer"]

    def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


MOBILEVITV2_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config ([`MobileViTV2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

MOBILEVITV2_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`MobileViTImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare MobileViTV2 model outputting raw hidden-states without any specific head on top.",
    MOBILEVITV2_START_DOCSTRING,
)
class MobileViTV2Model(MobileViTV2PreTrainedModel):
    def __init__(self, config: MobileViTV2Config, expand_output: bool = True):
        super().__init__(config)
        self.config = config
        self.expand_output = expand_output

        layer_0_dim = make_divisible(
            clip(value=32 * config.width_multiplier, min_val=16, max_val=64), divisor=8, min_value=16
        )

        self.conv_stem = MobileViTV2ConvLayer(
            config,
            in_channels=config.num_channels,
            out_channels=layer_0_dim,
            kernel_size=3,
            stride=2,
            use_normalization=True,
            use_activation=True,
        )
        self.encoder = MobileViTV2Encoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.

        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
        """
        for layer_index, heads in heads_to_prune.items():
            mobilevitv2_layer = self.encoder.layer[layer_index]
            if isinstance(mobilevitv2_layer, MobileViTV2Layer):
                for transformer_layer in mobilevitv2_layer.transformer.layer:
                    transformer_layer.attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(MOBILEVITV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.conv_stem(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.expand_output:
            last_hidden_state = encoder_outputs[0]

            # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
            pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)
        else:
            last_hidden_state = encoder_outputs[0]
            pooled_output = None

        if not return_dict:
            output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
            return output + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )


@add_start_docstrings(
    """
    MobileViTV2 model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    MOBILEVITV2_START_DOCSTRING,
)
class MobileViTV2ForImageClassification(MobileViTV2PreTrainedModel):
    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.mobilevitv2 = MobileViTV2Model(config)

        out_channels = make_divisible(512 * config.width_multiplier, divisor=8)  # layer 5 output dimension
        # Classifier head
        self.classifier = (
            nn.Linear(in_features=out_channels, out_features=config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOBILEVITV2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mobilevitv2(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
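
# Example usage of the classification head above (a sketch mirroring the doc sample generated by
# `add_code_sample_docstrings`; the checkpoint name is `_IMAGE_CLASS_CHECKPOINT` defined earlier,
# and `image` stands in for any PIL image):
#
#     from transformers import AutoImageProcessor, MobileViTV2ForImageClassification
#
#     image_processor = AutoImageProcessor.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
#     model = MobileViTV2ForImageClassification.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
#     inputs = image_processor(images=image, return_tensors="pt")
#     with torch.no_grad():
#         logits = model(**inputs).logits
#     predicted_label = model.config.id2label[logits.argmax(-1).item()]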


# Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTASPPPooling with MobileViT->MobileViTV2
class MobileViTV2ASPPPooling(nn.Module):
    def __init__(self, config: MobileViTV2Config, in_channels: int, out_channels: int) -> None:
        super().__init__()

        self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)

        self.conv_1x1 = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            use_normalization=True,
            use_activation="relu",
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        spatial_size = features.shape[-2:]
        features = self.global_pool(features)
        features = self.conv_1x1(features)
        features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False)
        return features


class MobileViTV2ASPP(nn.Module):
    """
    ASPP module defined in DeepLab papers: https://arxiv.org/abs/1606.00915, https://arxiv.org/abs/1706.05587
    """

    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__()

        encoder_out_channels = make_divisible(512 * config.width_multiplier, divisor=8)  # layer 5 output dimension
        in_channels = encoder_out_channels
        out_channels = config.aspp_out_channels

        if len(config.atrous_rates) != 3:
            raise ValueError("Expected 3 values for atrous_rates")

        self.convs = nn.ModuleList()

        in_projection = MobileViTV2ConvLayer(
            config,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_activation="relu",
        )
        self.convs.append(in_projection)

        self.convs.extend(
            [
                MobileViTV2ConvLayer(
                    config,
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    dilation=rate,
                    use_activation="relu",
                )
                for rate in config.atrous_rates
            ]
        )

        pool_layer = MobileViTV2ASPPPooling(config, in_channels, out_channels)
        self.convs.append(pool_layer)

        self.project = MobileViTV2ConvLayer(
            config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu"
        )

        self.dropout = nn.Dropout(p=config.aspp_dropout_prob)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        pyramid = []
        for conv in self.convs:
            pyramid.append(conv(features))
        pyramid = torch.cat(pyramid, dim=1)

        pooled_features = self.project(pyramid)
        pooled_features = self.dropout(pooled_features)
        return pooled_features


# Copied from transformers.models.mobilevit.modeling_mobilevit.MobileViTDeepLabV3 with MobileViT->MobileViTV2
class MobileViTV2DeepLabV3(nn.Module):
    """
    DeepLabv3 architecture: https://arxiv.org/abs/1706.05587
    """

    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__()
        self.aspp = MobileViTV2ASPP(config)

        self.dropout = nn.Dropout2d(config.classifier_dropout_prob)

        self.classifier = MobileViTV2ConvLayer(
            config,
            in_channels=config.aspp_out_channels,
            out_channels=config.num_labels,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
            bias=True,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        features = self.aspp(hidden_states[-1])
        features = self.dropout(features)
        features = self.classifier(features)
        return features


@add_start_docstrings(
    """
    MobileViTV2 model with a semantic segmentation head on top, e.g. for Pascal VOC.
    """,
    MOBILEVITV2_START_DOCSTRING,
)
class MobileViTV2ForSemanticSegmentation(MobileViTV2PreTrainedModel):
    def __init__(self, config: MobileViTV2Config) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.mobilevitv2 = MobileViTV2Model(config, expand_output=False)

        self.segmentation_head = MobileViTV2DeepLabV3(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOBILEVITV2_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from transformers import AutoImageProcessor, MobileViTV2ForSemanticSegmentation

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")
        >>> model = MobileViTV2ForSemanticSegmentation.from_pretrained("apple/mobilevitv2-1.0-imagenet1k-256")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and self.config.num_labels == 1:
            raise ValueError("The number of labels should be greater than one")

        outputs = self.mobilevitv2(
            pixel_values,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.segmentation_head(encoder_hidden_states)

        loss = None
        if labels is not None:
            # upsample logits to the images' original size
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
            loss = loss_fct(upsampled_logits, labels)

        if not return_dict:
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]

            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )