modeling_mobilevit.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072
  1. # coding=utf-8
  2. # Copyright 2022 Apple Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # Original license: https://github.com/apple/ml-cvnets/blob/main/LICENSE
  17. """PyTorch MobileViT model."""
  18. import math
  19. from typing import Dict, Optional, Set, Tuple, Union
  20. import torch
  21. import torch.utils.checkpoint
  22. from torch import nn
  23. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  24. from ...activations import ACT2FN
  25. from ...modeling_outputs import (
  26. BaseModelOutputWithNoAttention,
  27. BaseModelOutputWithPoolingAndNoAttention,
  28. ImageClassifierOutputWithNoAttention,
  29. SemanticSegmenterOutput,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  33. from ...utils import (
  34. add_code_sample_docstrings,
  35. add_start_docstrings,
  36. add_start_docstrings_to_model_forward,
  37. logging,
  38. replace_return_docstrings,
  39. torch_int,
  40. )
  41. from .configuration_mobilevit import MobileViTConfig
  42. logger = logging.get_logger(__name__)
  43. # General docstring
  44. _CONFIG_FOR_DOC = "MobileViTConfig"
  45. # Base docstring
  46. _CHECKPOINT_FOR_DOC = "apple/mobilevit-small"
  47. _EXPECTED_OUTPUT_SHAPE = [1, 640, 8, 8]
  48. # Image classification docstring
  49. _IMAGE_CLASS_CHECKPOINT = "apple/mobilevit-small"
  50. _IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"
  51. def make_divisible(value: int, divisor: int = 8, min_value: Optional[int] = None) -> int:
  52. """
  53. Ensure that all layers have a channel count that is divisible by `divisor`. This function is taken from the
  54. original TensorFlow repo. It can be seen here:
  55. https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
  56. """
  57. if min_value is None:
  58. min_value = divisor
  59. new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
  60. # Make sure that round down does not go down by more than 10%.
  61. if new_value < 0.9 * value:
  62. new_value += divisor
  63. return int(new_value)
  64. class MobileViTConvLayer(nn.Module):
  65. def __init__(
  66. self,
  67. config: MobileViTConfig,
  68. in_channels: int,
  69. out_channels: int,
  70. kernel_size: int,
  71. stride: int = 1,
  72. groups: int = 1,
  73. bias: bool = False,
  74. dilation: int = 1,
  75. use_normalization: bool = True,
  76. use_activation: Union[bool, str] = True,
  77. ) -> None:
  78. super().__init__()
  79. padding = int((kernel_size - 1) / 2) * dilation
  80. if in_channels % groups != 0:
  81. raise ValueError(f"Input channels ({in_channels}) are not divisible by {groups} groups.")
  82. if out_channels % groups != 0:
  83. raise ValueError(f"Output channels ({out_channels}) are not divisible by {groups} groups.")
  84. self.convolution = nn.Conv2d(
  85. in_channels=in_channels,
  86. out_channels=out_channels,
  87. kernel_size=kernel_size,
  88. stride=stride,
  89. padding=padding,
  90. dilation=dilation,
  91. groups=groups,
  92. bias=bias,
  93. padding_mode="zeros",
  94. )
  95. if use_normalization:
  96. self.normalization = nn.BatchNorm2d(
  97. num_features=out_channels,
  98. eps=1e-5,
  99. momentum=0.1,
  100. affine=True,
  101. track_running_stats=True,
  102. )
  103. else:
  104. self.normalization = None
  105. if use_activation:
  106. if isinstance(use_activation, str):
  107. self.activation = ACT2FN[use_activation]
  108. elif isinstance(config.hidden_act, str):
  109. self.activation = ACT2FN[config.hidden_act]
  110. else:
  111. self.activation = config.hidden_act
  112. else:
  113. self.activation = None
  114. def forward(self, features: torch.Tensor) -> torch.Tensor:
  115. features = self.convolution(features)
  116. if self.normalization is not None:
  117. features = self.normalization(features)
  118. if self.activation is not None:
  119. features = self.activation(features)
  120. return features
  121. class MobileViTInvertedResidual(nn.Module):
  122. """
  123. Inverted residual block (MobileNetv2): https://arxiv.org/abs/1801.04381
  124. """
  125. def __init__(
  126. self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int, dilation: int = 1
  127. ) -> None:
  128. super().__init__()
  129. expanded_channels = make_divisible(int(round(in_channels * config.expand_ratio)), 8)
  130. if stride not in [1, 2]:
  131. raise ValueError(f"Invalid stride {stride}.")
  132. self.use_residual = (stride == 1) and (in_channels == out_channels)
  133. self.expand_1x1 = MobileViTConvLayer(
  134. config, in_channels=in_channels, out_channels=expanded_channels, kernel_size=1
  135. )
  136. self.conv_3x3 = MobileViTConvLayer(
  137. config,
  138. in_channels=expanded_channels,
  139. out_channels=expanded_channels,
  140. kernel_size=3,
  141. stride=stride,
  142. groups=expanded_channels,
  143. dilation=dilation,
  144. )
  145. self.reduce_1x1 = MobileViTConvLayer(
  146. config,
  147. in_channels=expanded_channels,
  148. out_channels=out_channels,
  149. kernel_size=1,
  150. use_activation=False,
  151. )
  152. def forward(self, features: torch.Tensor) -> torch.Tensor:
  153. residual = features
  154. features = self.expand_1x1(features)
  155. features = self.conv_3x3(features)
  156. features = self.reduce_1x1(features)
  157. return residual + features if self.use_residual else features
  158. class MobileViTMobileNetLayer(nn.Module):
  159. def __init__(
  160. self, config: MobileViTConfig, in_channels: int, out_channels: int, stride: int = 1, num_stages: int = 1
  161. ) -> None:
  162. super().__init__()
  163. self.layer = nn.ModuleList()
  164. for i in range(num_stages):
  165. layer = MobileViTInvertedResidual(
  166. config,
  167. in_channels=in_channels,
  168. out_channels=out_channels,
  169. stride=stride if i == 0 else 1,
  170. )
  171. self.layer.append(layer)
  172. in_channels = out_channels
  173. def forward(self, features: torch.Tensor) -> torch.Tensor:
  174. for layer_module in self.layer:
  175. features = layer_module(features)
  176. return features
  177. class MobileViTSelfAttention(nn.Module):
  178. def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
  179. super().__init__()
  180. if hidden_size % config.num_attention_heads != 0:
  181. raise ValueError(
  182. f"The hidden size {hidden_size,} is not a multiple of the number of attention "
  183. f"heads {config.num_attention_heads}."
  184. )
  185. self.num_attention_heads = config.num_attention_heads
  186. self.attention_head_size = int(hidden_size / config.num_attention_heads)
  187. self.all_head_size = self.num_attention_heads * self.attention_head_size
  188. self.query = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  189. self.key = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  190. self.value = nn.Linear(hidden_size, self.all_head_size, bias=config.qkv_bias)
  191. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  192. def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
  193. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  194. x = x.view(*new_x_shape)
  195. return x.permute(0, 2, 1, 3)
  196. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  197. mixed_query_layer = self.query(hidden_states)
  198. key_layer = self.transpose_for_scores(self.key(hidden_states))
  199. value_layer = self.transpose_for_scores(self.value(hidden_states))
  200. query_layer = self.transpose_for_scores(mixed_query_layer)
  201. # Take the dot product between "query" and "key" to get the raw attention scores.
  202. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  203. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  204. # Normalize the attention scores to probabilities.
  205. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  206. # This is actually dropping out entire tokens to attend to, which might
  207. # seem a bit unusual, but is taken from the original Transformer paper.
  208. attention_probs = self.dropout(attention_probs)
  209. context_layer = torch.matmul(attention_probs, value_layer)
  210. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  211. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  212. context_layer = context_layer.view(*new_context_layer_shape)
  213. return context_layer
  214. class MobileViTSelfOutput(nn.Module):
  215. def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
  216. super().__init__()
  217. self.dense = nn.Linear(hidden_size, hidden_size)
  218. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  219. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  220. hidden_states = self.dense(hidden_states)
  221. hidden_states = self.dropout(hidden_states)
  222. return hidden_states
  223. class MobileViTAttention(nn.Module):
  224. def __init__(self, config: MobileViTConfig, hidden_size: int) -> None:
  225. super().__init__()
  226. self.attention = MobileViTSelfAttention(config, hidden_size)
  227. self.output = MobileViTSelfOutput(config, hidden_size)
  228. self.pruned_heads = set()
  229. def prune_heads(self, heads: Set[int]) -> None:
  230. if len(heads) == 0:
  231. return
  232. heads, index = find_pruneable_heads_and_indices(
  233. heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
  234. )
  235. # Prune linear layers
  236. self.attention.query = prune_linear_layer(self.attention.query, index)
  237. self.attention.key = prune_linear_layer(self.attention.key, index)
  238. self.attention.value = prune_linear_layer(self.attention.value, index)
  239. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  240. # Update hyper params and store pruned heads
  241. self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
  242. self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
  243. self.pruned_heads = self.pruned_heads.union(heads)
  244. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  245. self_outputs = self.attention(hidden_states)
  246. attention_output = self.output(self_outputs)
  247. return attention_output
  248. class MobileViTIntermediate(nn.Module):
  249. def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
  250. super().__init__()
  251. self.dense = nn.Linear(hidden_size, intermediate_size)
  252. if isinstance(config.hidden_act, str):
  253. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  254. else:
  255. self.intermediate_act_fn = config.hidden_act
  256. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  257. hidden_states = self.dense(hidden_states)
  258. hidden_states = self.intermediate_act_fn(hidden_states)
  259. return hidden_states
  260. class MobileViTOutput(nn.Module):
  261. def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
  262. super().__init__()
  263. self.dense = nn.Linear(intermediate_size, hidden_size)
  264. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  265. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  266. hidden_states = self.dense(hidden_states)
  267. hidden_states = self.dropout(hidden_states)
  268. hidden_states = hidden_states + input_tensor
  269. return hidden_states
  270. class MobileViTTransformerLayer(nn.Module):
  271. def __init__(self, config: MobileViTConfig, hidden_size: int, intermediate_size: int) -> None:
  272. super().__init__()
  273. self.attention = MobileViTAttention(config, hidden_size)
  274. self.intermediate = MobileViTIntermediate(config, hidden_size, intermediate_size)
  275. self.output = MobileViTOutput(config, hidden_size, intermediate_size)
  276. self.layernorm_before = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  277. self.layernorm_after = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
  278. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  279. attention_output = self.attention(self.layernorm_before(hidden_states))
  280. hidden_states = attention_output + hidden_states
  281. layer_output = self.layernorm_after(hidden_states)
  282. layer_output = self.intermediate(layer_output)
  283. layer_output = self.output(layer_output, hidden_states)
  284. return layer_output
  285. class MobileViTTransformer(nn.Module):
  286. def __init__(self, config: MobileViTConfig, hidden_size: int, num_stages: int) -> None:
  287. super().__init__()
  288. self.layer = nn.ModuleList()
  289. for _ in range(num_stages):
  290. transformer_layer = MobileViTTransformerLayer(
  291. config,
  292. hidden_size=hidden_size,
  293. intermediate_size=int(hidden_size * config.mlp_ratio),
  294. )
  295. self.layer.append(transformer_layer)
  296. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  297. for layer_module in self.layer:
  298. hidden_states = layer_module(hidden_states)
  299. return hidden_states
class MobileViTLayer(nn.Module):
    """
    MobileViT block: https://arxiv.org/abs/2110.02178

    Local features are computed with convolutions, the feature map is unfolded into
    patch sequences processed globally by a transformer, then folded back and fused
    with the block input.
    """

    def __init__(
        self,
        config: MobileViTConfig,
        in_channels: int,
        out_channels: int,
        stride: int,
        hidden_size: int,
        num_stages: int,
        dilation: int = 1,
    ) -> None:
        super().__init__()
        self.patch_width = config.patch_size
        self.patch_height = config.patch_size

        if stride == 2:
            # Downsample with an inverted residual block; when dilation > 1
            # (segmentation backbones) the stride is dropped and the dilation halved.
            self.downsampling_layer = MobileViTInvertedResidual(
                config,
                in_channels=in_channels,
                out_channels=out_channels,
                stride=stride if dilation == 1 else 1,
                dilation=dilation // 2 if dilation > 1 else 1,
            )
            in_channels = out_channels
        else:
            self.downsampling_layer = None

        # Local representation: kxk conv followed by a 1x1 projection to the transformer width.
        self.conv_kxk = MobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=config.conv_kernel_size,
        )

        self.conv_1x1 = MobileViTConvLayer(
            config,
            in_channels=in_channels,
            out_channels=hidden_size,
            kernel_size=1,
            use_normalization=False,
            use_activation=False,
        )

        self.transformer = MobileViTTransformer(
            config,
            hidden_size=hidden_size,
            num_stages=num_stages,
        )

        self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)

        # Project the transformer output back to the convolutional channel count.
        self.conv_projection = MobileViTConvLayer(
            config, in_channels=hidden_size, out_channels=in_channels, kernel_size=1
        )

        # Fuse the residual with the globally processed features (concat -> conv).
        self.fusion = MobileViTConvLayer(
            config, in_channels=2 * in_channels, out_channels=in_channels, kernel_size=config.conv_kernel_size
        )

    def unfolding(self, features: torch.Tensor) -> Tuple[torch.Tensor, Dict]:
        """
        Rearrange a feature map (batch_size, channels, height, width) into patch token
        sequences of shape (batch_size * patch_area, num_patches, channels), bilinearly
        resizing first when the spatial dims are not multiples of the patch size.
        Returns the patches plus the bookkeeping dict `folding` needs to invert the op.
        """
        patch_width, patch_height = self.patch_width, self.patch_height
        patch_area = int(patch_width * patch_height)

        batch_size, channels, orig_height, orig_width = features.shape

        # Round the spatial dims up to the nearest multiple of the patch size.
        # torch ops keep the computation traceable under torch.jit.trace.
        new_height = (
            torch_int(torch.ceil(orig_height / patch_height) * patch_height)
            if torch.jit.is_tracing()
            else int(math.ceil(orig_height / patch_height) * patch_height)
        )
        new_width = (
            torch_int(torch.ceil(orig_width / patch_width) * patch_width)
            if torch.jit.is_tracing()
            else int(math.ceil(orig_width / patch_width) * patch_width)
        )

        interpolate = False
        if new_width != orig_width or new_height != orig_height:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            features = nn.functional.interpolate(
                features, size=(new_height, new_width), mode="bilinear", align_corners=False
            )
            interpolate = True

        # number of patches along width and height
        num_patch_width = new_width // patch_width
        num_patch_height = new_height // patch_height
        num_patches = num_patch_height * num_patch_width

        # convert from shape (batch_size, channels, orig_height, orig_width)
        # to the shape (batch_size * patch_area, num_patches, channels)
        patches = features.reshape(
            batch_size * channels * num_patch_height, patch_height, num_patch_width, patch_width
        )
        patches = patches.transpose(1, 2)
        patches = patches.reshape(batch_size, channels, num_patches, patch_area)
        patches = patches.transpose(1, 3)
        patches = patches.reshape(batch_size * patch_area, num_patches, -1)

        # Everything `folding` needs to restore the original layout.
        info_dict = {
            "orig_size": (orig_height, orig_width),
            "batch_size": batch_size,
            "channels": channels,
            "interpolate": interpolate,
            "num_patches": num_patches,
            "num_patches_width": num_patch_width,
            "num_patches_height": num_patch_height,
        }
        return patches, info_dict

    def folding(self, patches: torch.Tensor, info_dict: Dict) -> torch.Tensor:
        """
        Inverse of `unfolding`: rebuild the (batch_size, channels, height, width) feature
        map from patch tokens, resizing back to the original spatial dims if `unfolding`
        had to interpolate.
        """
        patch_width, patch_height = self.patch_width, self.patch_height
        patch_area = int(patch_width * patch_height)

        batch_size = info_dict["batch_size"]
        channels = info_dict["channels"]
        num_patches = info_dict["num_patches"]
        num_patch_height = info_dict["num_patches_height"]
        num_patch_width = info_dict["num_patches_width"]

        # convert from shape (batch_size * patch_area, num_patches, channels)
        # back to shape (batch_size, channels, orig_height, orig_width)
        features = patches.contiguous().view(batch_size, patch_area, num_patches, -1)
        features = features.transpose(1, 3)
        features = features.reshape(
            batch_size * channels * num_patch_height, num_patch_width, patch_height, patch_width
        )
        features = features.transpose(1, 2)
        features = features.reshape(
            batch_size, channels, num_patch_height * patch_height, num_patch_width * patch_width
        )

        if info_dict["interpolate"]:
            features = nn.functional.interpolate(
                features, size=info_dict["orig_size"], mode="bilinear", align_corners=False
            )

        return features

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # reduce spatial dimensions if needed
        if self.downsampling_layer:
            features = self.downsampling_layer(features)

        residual = features

        # local representation
        features = self.conv_kxk(features)
        features = self.conv_1x1(features)

        # convert feature map to patches
        patches, info_dict = self.unfolding(features)

        # learn global representations
        patches = self.transformer(patches)
        patches = self.layernorm(patches)

        # convert patches back to feature maps
        features = self.folding(patches, info_dict)

        features = self.conv_projection(features)
        # fuse with the residual path: concat along channels, then conv
        features = self.fusion(torch.cat((residual, features), dim=1))
        return features
class MobileViTEncoder(nn.Module):
    """
    The MobileViT backbone: two MobileNet (inverted residual) stages followed by
    three MobileViT stages. Later stages can use dilated convolutions instead of
    striding so segmentation heads get a smaller output stride.
    """

    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__()
        self.config = config

        self.layer = nn.ModuleList()
        # Toggled by PreTrainedModel.gradient_checkpointing_enable().
        self.gradient_checkpointing = False

        # segmentation architectures like DeepLab and PSPNet modify the strides
        # of the classification backbones
        dilate_layer_4 = dilate_layer_5 = False
        if config.output_stride == 8:
            dilate_layer_4 = True
            dilate_layer_5 = True
        elif config.output_stride == 16:
            dilate_layer_5 = True

        dilation = 1

        layer_1 = MobileViTMobileNetLayer(
            config,
            in_channels=config.neck_hidden_sizes[0],
            out_channels=config.neck_hidden_sizes[1],
            stride=1,
            num_stages=1,
        )
        self.layer.append(layer_1)

        layer_2 = MobileViTMobileNetLayer(
            config,
            in_channels=config.neck_hidden_sizes[1],
            out_channels=config.neck_hidden_sizes[2],
            stride=2,
            num_stages=3,
        )
        self.layer.append(layer_2)

        layer_3 = MobileViTLayer(
            config,
            in_channels=config.neck_hidden_sizes[2],
            out_channels=config.neck_hidden_sizes[3],
            stride=2,
            hidden_size=config.hidden_sizes[0],
            num_stages=2,
        )
        self.layer.append(layer_3)

        # Doubling the dilation replaces the stride-2 downsampling inside the layer.
        if dilate_layer_4:
            dilation *= 2

        layer_4 = MobileViTLayer(
            config,
            in_channels=config.neck_hidden_sizes[3],
            out_channels=config.neck_hidden_sizes[4],
            stride=2,
            hidden_size=config.hidden_sizes[1],
            num_stages=4,
            dilation=dilation,
        )
        self.layer.append(layer_4)

        if dilate_layer_5:
            dilation *= 2

        layer_5 = MobileViTLayer(
            config,
            in_channels=config.neck_hidden_sizes[4],
            out_channels=config.neck_hidden_sizes[5],
            stride=2,
            hidden_size=config.hidden_sizes[2],
            num_stages=3,
            dilation=dilation,
        )
        self.layer.append(layer_5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutputWithNoAttention]:
        """
        Run the five stages sequentially; optionally collect each stage's output
        and return either a tuple or a BaseModelOutputWithNoAttention.
        """
        all_hidden_states = () if output_hidden_states else None

        for i, layer_module in enumerate(self.layer):
            if self.gradient_checkpointing and self.training:
                # Trade compute for memory by re-running the stage in backward.
                hidden_states = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                )
            else:
                hidden_states = layer_module(hidden_states)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
  524. class MobileViTPreTrainedModel(PreTrainedModel):
  525. """
  526. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  527. models.
  528. """
  529. config_class = MobileViTConfig
  530. base_model_prefix = "mobilevit"
  531. main_input_name = "pixel_values"
  532. supports_gradient_checkpointing = True
  533. _no_split_modules = ["MobileViTLayer"]
  534. def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
  535. """Initialize the weights"""
  536. if isinstance(module, (nn.Linear, nn.Conv2d)):
  537. # Slightly different from the TF version which uses truncated_normal for initialization
  538. # cf https://github.com/pytorch/pytorch/pull/5617
  539. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  540. if module.bias is not None:
  541. module.bias.data.zero_()
  542. elif isinstance(module, nn.LayerNorm):
  543. module.bias.data.zero_()
  544. module.weight.data.fill_(1.0)
# Docstring fragment injected into model classes via @add_start_docstrings.
MOBILEVIT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`MobileViTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Docstring fragment injected into forward() via @add_start_docstrings_to_model_forward.
MOBILEVIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`MobileViTImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare MobileViT model outputting raw hidden-states without any specific head on top.",
    MOBILEVIT_START_DOCSTRING,
)
class MobileViTModel(MobileViTPreTrainedModel):
    def __init__(self, config: MobileViTConfig, expand_output: bool = True):
        super().__init__(config)
        self.config = config
        # expand_output=False is used by heads (e.g. segmentation) that do not want
        # the final 1x1 channel expansion or the pooled output.
        self.expand_output = expand_output

        self.conv_stem = MobileViTConvLayer(
            config,
            in_channels=config.num_channels,
            out_channels=config.neck_hidden_sizes[0],
            kernel_size=3,
            stride=2,
        )

        self.encoder = MobileViTEncoder(config)

        if self.expand_output:
            self.conv_1x1_exp = MobileViTConvLayer(
                config,
                in_channels=config.neck_hidden_sizes[5],
                out_channels=config.neck_hidden_sizes[6],
                kernel_size=1,
            )

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base class PreTrainedModel
        """
        for layer_index, heads in heads_to_prune.items():
            mobilevit_layer = self.encoder.layer[layer_index]
            # Only MobileViT stages contain transformer layers with prunable heads.
            if isinstance(mobilevit_layer, MobileViTLayer):
                for transformer_layer in mobilevit_layer.transformer.layer:
                    transformer_layer.attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutputWithPoolingAndNoAttention]:
        # Fall back to the config-level defaults when the caller passes None.
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.conv_stem(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.expand_output:
            last_hidden_state = self.conv_1x1_exp(encoder_outputs[0])

            # global average pooling: (batch_size, channels, height, width) -> (batch_size, channels)
            pooled_output = torch.mean(last_hidden_state, dim=[-2, -1], keepdim=False)
        else:
            last_hidden_state = encoder_outputs[0]
            pooled_output = None

        if not return_dict:
            output = (last_hidden_state, pooled_output) if pooled_output is not None else (last_hidden_state,)
            return output + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
@add_start_docstrings(
    """
    MobileViT model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    MOBILEVIT_START_DOCSTRING,
)
class MobileViTForImageClassification(MobileViTPreTrainedModel):
    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        self.mobilevit = MobileViTModel(config)

        # Classifier head: dropout followed by a linear projection from the last
        # neck hidden size to `num_labels` (identity when num_labels == 0).
        # NOTE(review): `inplace=True` makes dropout overwrite the pooled tensor
        # produced by `self.mobilevit` during training — confirm no caller reuses it.
        self.dropout = nn.Dropout(config.classifier_dropout_prob, inplace=True)
        self.classifier = (
            nn.Linear(config.neck_hidden_sizes[-1], config.num_labels) if config.num_labels > 0 else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.mobilevit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        # Pooled features are at index 1 of the tuple output (after last_hidden_state).
        pooled_output = outputs.pooler_output if return_dict else outputs[1]

        logits = self.classifier(self.dropout(pooled_output))

        loss = None
        if labels is not None:
            # Infer the problem type once from num_labels and the label dtype;
            # the result is cached on the config for subsequent calls.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            # Skip last_hidden_state and pooler_output (indices 0 and 1); keep hidden states if present.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
  713. class MobileViTASPPPooling(nn.Module):
  714. def __init__(self, config: MobileViTConfig, in_channels: int, out_channels: int) -> None:
  715. super().__init__()
  716. self.global_pool = nn.AdaptiveAvgPool2d(output_size=1)
  717. self.conv_1x1 = MobileViTConvLayer(
  718. config,
  719. in_channels=in_channels,
  720. out_channels=out_channels,
  721. kernel_size=1,
  722. stride=1,
  723. use_normalization=True,
  724. use_activation="relu",
  725. )
  726. def forward(self, features: torch.Tensor) -> torch.Tensor:
  727. spatial_size = features.shape[-2:]
  728. features = self.global_pool(features)
  729. features = self.conv_1x1(features)
  730. features = nn.functional.interpolate(features, size=spatial_size, mode="bilinear", align_corners=False)
  731. return features
  732. class MobileViTASPP(nn.Module):
  733. """
  734. ASPP module defined in DeepLab papers: https://arxiv.org/abs/1606.00915, https://arxiv.org/abs/1706.05587
  735. """
  736. def __init__(self, config: MobileViTConfig) -> None:
  737. super().__init__()
  738. in_channels = config.neck_hidden_sizes[-2]
  739. out_channels = config.aspp_out_channels
  740. if len(config.atrous_rates) != 3:
  741. raise ValueError("Expected 3 values for atrous_rates")
  742. self.convs = nn.ModuleList()
  743. in_projection = MobileViTConvLayer(
  744. config,
  745. in_channels=in_channels,
  746. out_channels=out_channels,
  747. kernel_size=1,
  748. use_activation="relu",
  749. )
  750. self.convs.append(in_projection)
  751. self.convs.extend(
  752. [
  753. MobileViTConvLayer(
  754. config,
  755. in_channels=in_channels,
  756. out_channels=out_channels,
  757. kernel_size=3,
  758. dilation=rate,
  759. use_activation="relu",
  760. )
  761. for rate in config.atrous_rates
  762. ]
  763. )
  764. pool_layer = MobileViTASPPPooling(config, in_channels, out_channels)
  765. self.convs.append(pool_layer)
  766. self.project = MobileViTConvLayer(
  767. config, in_channels=5 * out_channels, out_channels=out_channels, kernel_size=1, use_activation="relu"
  768. )
  769. self.dropout = nn.Dropout(p=config.aspp_dropout_prob)
  770. def forward(self, features: torch.Tensor) -> torch.Tensor:
  771. pyramid = []
  772. for conv in self.convs:
  773. pyramid.append(conv(features))
  774. pyramid = torch.cat(pyramid, dim=1)
  775. pooled_features = self.project(pyramid)
  776. pooled_features = self.dropout(pooled_features)
  777. return pooled_features
  778. class MobileViTDeepLabV3(nn.Module):
  779. """
  780. DeepLabv3 architecture: https://arxiv.org/abs/1706.05587
  781. """
  782. def __init__(self, config: MobileViTConfig) -> None:
  783. super().__init__()
  784. self.aspp = MobileViTASPP(config)
  785. self.dropout = nn.Dropout2d(config.classifier_dropout_prob)
  786. self.classifier = MobileViTConvLayer(
  787. config,
  788. in_channels=config.aspp_out_channels,
  789. out_channels=config.num_labels,
  790. kernel_size=1,
  791. use_normalization=False,
  792. use_activation=False,
  793. bias=True,
  794. )
  795. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  796. features = self.aspp(hidden_states[-1])
  797. features = self.dropout(features)
  798. features = self.classifier(features)
  799. return features
@add_start_docstrings(
    """
    MobileViT model with a semantic segmentation head on top, e.g. for Pascal VOC.
    """,
    MOBILEVIT_START_DOCSTRING,
)
class MobileViTForSemanticSegmentation(MobileViTPreTrainedModel):
    def __init__(self, config: MobileViTConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        # expand_output=False: the backbone skips the final 1x1 expansion and pooling,
        # so raw stage feature maps are available for the segmentation head.
        self.mobilevit = MobileViTModel(config, expand_output=False)
        self.segmentation_head = MobileViTDeepLabV3(config)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOBILEVIT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=SemanticSegmenterOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, SemanticSegmenterOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
            Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from transformers import AutoImageProcessor, MobileViTForSemanticSegmentation

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> image_processor = AutoImageProcessor.from_pretrained("apple/deeplabv3-mobilevit-small")
        >>> model = MobileViTForSemanticSegmentation.from_pretrained("apple/deeplabv3-mobilevit-small")

        >>> inputs = image_processor(images=image, return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> # logits are of shape (batch_size, num_labels, height, width)
        >>> logits = outputs.logits
        ```"""
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Binary segmentation is not supported: cross-entropy needs >= 2 classes.
        if labels is not None and self.config.num_labels == 1:
            raise ValueError("The number of labels should be greater than one")

        outputs = self.mobilevit(
            pixel_values,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=return_dict,
        )

        # With expand_output=False there is no pooler output, so hidden states sit at index 1.
        encoder_hidden_states = outputs.hidden_states if return_dict else outputs[1]

        logits = self.segmentation_head(encoder_hidden_states)

        loss = None
        if labels is not None:
            # upsample logits to the images' original size
            upsampled_logits = nn.functional.interpolate(
                logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
            )
            loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
            loss = loss_fct(upsampled_logits, labels)

        if not return_dict:
            # Hidden states were always computed (forced above); only expose them
            # in the tuple when the caller actually asked for them.
            if output_hidden_states:
                output = (logits,) + outputs[1:]
            else:
                output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=None,
        )