# coding=utf-8
# Copyright 2022 Meta Platforms, Inc. and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch LeViT model."""

import itertools
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...modeling_outputs import (
    BaseModelOutputWithNoAttention,
    BaseModelOutputWithPoolingAndNoAttention,
    ImageClassifierOutputWithNoAttention,
    ModelOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_levit import LevitConfig


logger = logging.get_logger(__name__)

# General docstring
_CONFIG_FOR_DOC = "LevitConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/levit-128S"
_EXPECTED_OUTPUT_SHAPE = [1, 16, 384]

# Image classification docstring
_IMAGE_CLASS_CHECKPOINT = "facebook/levit-128S"
_IMAGE_CLASS_EXPECTED_OUTPUT = "tabby, tabby cat"

@dataclass
class LevitForImageClassificationWithTeacherOutput(ModelOutput):
    """
    Output type of [`LevitForImageClassificationWithTeacher`].

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores as the average of the `cls_logits` and `distillation_logits`.
        cls_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the classification head (i.e. the linear layer on top of the final hidden state of
            the class token).
        distillation_logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
            Prediction scores of the distillation head (i.e. the linear layer on top of the final hidden state of the
            distillation token).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
    """

    logits: torch.FloatTensor = None
    cls_logits: torch.FloatTensor = None
    distillation_logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None

class LevitConvEmbeddings(nn.Module):
    """
    LeViT Conv Embeddings with Batch Norm, used in the initial patch embedding layer.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, padding, dilation=1, groups=1, bn_weight_init=1
    ):
        super().__init__()
        self.convolution = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=False
        )
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, embeddings):
        embeddings = self.convolution(embeddings)
        embeddings = self.batch_norm(embeddings)
        return embeddings


class LevitPatchEmbeddings(nn.Module):
    """
    LeViT patch embeddings, for final embeddings to be passed to transformer blocks. It consists of multiple
    `LevitConvEmbeddings`.
    """

    def __init__(self, config):
        super().__init__()
        self.embedding_layer_1 = LevitConvEmbeddings(
            config.num_channels, config.hidden_sizes[0] // 8, config.kernel_size, config.stride, config.padding
        )
        self.activation_layer_1 = nn.Hardswish()
        self.embedding_layer_2 = LevitConvEmbeddings(
            config.hidden_sizes[0] // 8, config.hidden_sizes[0] // 4, config.kernel_size, config.stride, config.padding
        )
        self.activation_layer_2 = nn.Hardswish()
        self.embedding_layer_3 = LevitConvEmbeddings(
            config.hidden_sizes[0] // 4, config.hidden_sizes[0] // 2, config.kernel_size, config.stride, config.padding
        )
        self.activation_layer_3 = nn.Hardswish()
        self.embedding_layer_4 = LevitConvEmbeddings(
            config.hidden_sizes[0] // 2, config.hidden_sizes[0], config.kernel_size, config.stride, config.padding
        )
        self.num_channels = config.num_channels

    def forward(self, pixel_values):
        num_channels = pixel_values.shape[1]
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        embeddings = self.embedding_layer_1(pixel_values)
        embeddings = self.activation_layer_1(embeddings)
        embeddings = self.embedding_layer_2(embeddings)
        embeddings = self.activation_layer_2(embeddings)
        embeddings = self.embedding_layer_3(embeddings)
        embeddings = self.activation_layer_3(embeddings)
        embeddings = self.embedding_layer_4(embeddings)
        return embeddings.flatten(2).transpose(1, 2)
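
# Shape sketch for the patch embedding above (illustrative only, assuming the default 224x224 input and the
# facebook/levit-128S configuration with hidden_sizes[0] = 128): each of the four stride-2 convolutions halves
# the spatial resolution while the channel count grows 16 -> 32 -> 64 -> 128, so
# (batch, 3, 224, 224) -> (batch, 128, 14, 14), flattened and transposed to (batch, 196, 128) token embeddings.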

class MLPLayerWithBN(nn.Module):
    def __init__(self, input_dim, output_dim, bn_weight_init=1):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim, bias=False)
        self.batch_norm = nn.BatchNorm1d(output_dim)

    def forward(self, hidden_state):
        hidden_state = self.linear(hidden_state)
        hidden_state = self.batch_norm(hidden_state.flatten(0, 1)).reshape_as(hidden_state)
        return hidden_state


class LevitSubsample(nn.Module):
    def __init__(self, stride, resolution):
        super().__init__()
        self.stride = stride
        self.resolution = resolution

    def forward(self, hidden_state):
        batch_size, _, channels = hidden_state.shape
        hidden_state = hidden_state.view(batch_size, self.resolution, self.resolution, channels)[
            :, :: self.stride, :: self.stride
        ].reshape(batch_size, -1, channels)
        return hidden_state
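
# LevitSubsample simply keeps every `stride`-th token of the 2D token grid. As an illustrative example with the
# default configuration: a 14 x 14 grid of 196 tokens subsampled with stride 2 becomes a 7 x 7 grid of 49 tokens.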

class LevitAttention(nn.Module):
    def __init__(self, hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio
        self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads * 2
        self.out_dim_projection = attention_ratio * key_dim * num_attention_heads

        self.queries_keys_values = MLPLayerWithBN(hidden_sizes, self.out_dim_keys_values)
        self.activation = nn.Hardswish()
        self.projection = MLPLayerWithBN(self.out_dim_projection, hidden_sizes, bn_weight_init=0)

        points = list(itertools.product(range(resolution), range(resolution)))
        len_points = len(points)
        attention_offsets, indices = {}, []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                indices.append(attention_offsets[offset])

        self.attention_bias_cache = {}
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
        self.register_buffer(
            "attention_bias_idxs", torch.LongTensor(indices).view(len_points, len_points), persistent=False
        )

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache

    def get_attention_biases(self, device):
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    def forward(self, hidden_state):
        batch_size, seq_length, _ = hidden_state.shape
        queries_keys_values = self.queries_keys_values(hidden_state)
        query, key, value = queries_keys_values.view(batch_size, seq_length, self.num_attention_heads, -1).split(
            [self.key_dim, self.key_dim, self.attention_ratio * self.key_dim], dim=3
        )
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
        attention = attention.softmax(dim=-1)
        hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, seq_length, self.out_dim_projection)
        hidden_state = self.projection(self.activation(hidden_state))
        return hidden_state
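
# Dimension sketch for LevitAttention (illustrative only, assuming the first stage of facebook/levit-128S:
# hidden_sizes=128, key_dim=16, num_attention_heads=4, attention_ratio=2):
#   out_dim_keys_values = 2*16*4 + 16*4*2 = 256  (queries and keys get key_dim per head, values get
#                                                 attention_ratio * key_dim per head)
#   out_dim_projection  = 2*16*4 = 128            (concatenated per-head values before the output projection)
# The learned `attention_biases` are indexed by `attention_bias_idxs`, which maps every pair of grid positions to
# its absolute (dx, dy) offset, giving a translation-invariant relative position bias per attention head.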

class LevitAttentionSubsample(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        key_dim,
        num_attention_heads,
        attention_ratio,
        stride,
        resolution_in,
        resolution_out,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.attention_ratio = attention_ratio
        self.out_dim_keys_values = attention_ratio * key_dim * num_attention_heads + key_dim * num_attention_heads
        self.out_dim_projection = attention_ratio * key_dim * num_attention_heads
        self.resolution_out = resolution_out
        # resolution_in is the initial resolution, resolution_out is the final resolution after downsampling
        self.keys_values = MLPLayerWithBN(input_dim, self.out_dim_keys_values)
        self.queries_subsample = LevitSubsample(stride, resolution_in)
        self.queries = MLPLayerWithBN(input_dim, key_dim * num_attention_heads)
        self.activation = nn.Hardswish()
        self.projection = MLPLayerWithBN(self.out_dim_projection, output_dim)

        self.attention_bias_cache = {}

        points = list(itertools.product(range(resolution_in), range(resolution_in)))
        points_ = list(itertools.product(range(resolution_out), range(resolution_out)))
        len_points, len_points_ = len(points), len(points_)
        attention_offsets, indices = {}, []
        for p1 in points_:
            for p2 in points:
                size = 1
                offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2), abs(p1[1] * stride - p2[1] + (size - 1) / 2))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                indices.append(attention_offsets[offset])

        self.attention_biases = torch.nn.Parameter(torch.zeros(num_attention_heads, len(attention_offsets)))
        self.register_buffer(
            "attention_bias_idxs", torch.LongTensor(indices).view(len_points_, len_points), persistent=False
        )

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache

    def get_attention_biases(self, device):
        if self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    def forward(self, hidden_state):
        batch_size, seq_length, _ = hidden_state.shape
        key, value = (
            self.keys_values(hidden_state)
            .view(batch_size, seq_length, self.num_attention_heads, -1)
            .split([self.key_dim, self.attention_ratio * self.key_dim], dim=3)
        )
        key = key.permute(0, 2, 1, 3)
        value = value.permute(0, 2, 1, 3)

        query = self.queries(self.queries_subsample(hidden_state))
        query = query.view(batch_size, self.resolution_out**2, self.num_attention_heads, self.key_dim).permute(
            0, 2, 1, 3
        )

        attention = query @ key.transpose(-2, -1) * self.scale + self.get_attention_biases(hidden_state.device)
        attention = attention.softmax(dim=-1)
        hidden_state = (attention @ value).transpose(1, 2).reshape(batch_size, -1, self.out_dim_projection)
        hidden_state = self.projection(self.activation(hidden_state))
        return hidden_state
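
# In LevitAttentionSubsample only the queries are computed on the strided (subsampled) token grid, while keys and
# values still cover the full-resolution grid, so the attention output has resolution_out**2 tokens. This is how
# LeViT shrinks the sequence length between stages while letting every output token attend to all input tokens.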

class LevitMLPLayer(nn.Module):
    """
    MLP Layer with `2X` expansion in contrast to ViT with `4X`.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear_up = MLPLayerWithBN(input_dim, hidden_dim)
        self.activation = nn.Hardswish()
        self.linear_down = MLPLayerWithBN(hidden_dim, input_dim)

    def forward(self, hidden_state):
        hidden_state = self.linear_up(hidden_state)
        hidden_state = self.activation(hidden_state)
        hidden_state = self.linear_down(hidden_state)
        return hidden_state


class LevitResidualLayer(nn.Module):
    """
    Residual Block for LeViT
    """

    def __init__(self, module, drop_rate):
        super().__init__()
        self.module = module
        self.drop_rate = drop_rate

    def forward(self, hidden_state):
        if self.training and self.drop_rate > 0:
            rnd = torch.rand(hidden_state.size(0), 1, 1, device=hidden_state.device)
            rnd = rnd.ge_(self.drop_rate).div(1 - self.drop_rate).detach()
            hidden_state = hidden_state + self.module(hidden_state) * rnd
            return hidden_state
        else:
            hidden_state = hidden_state + self.module(hidden_state)
            return hidden_state
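
# The masking in LevitResidualLayer implements stochastic depth (drop path) at training time: a per-sample
# Bernoulli mask keeps the residual branch with probability 1 - drop_rate and rescales the surviving branches by
# 1 / (1 - drop_rate), so the expected output matches the deterministic path used at evaluation time.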

class LevitStage(nn.Module):
    """
    LeViT Stage consisting of `LevitMLPLayer` and `LevitAttention` layers.
    """

    def __init__(
        self,
        config,
        idx,
        hidden_sizes,
        key_dim,
        depths,
        num_attention_heads,
        attention_ratio,
        mlp_ratio,
        down_ops,
        resolution_in,
    ):
        super().__init__()
        self.layers = []
        self.config = config
        self.resolution_in = resolution_in
        # resolution_in is the initial resolution, resolution_out is the final resolution after downsampling
        for _ in range(depths):
            self.layers.append(
                LevitResidualLayer(
                    LevitAttention(hidden_sizes, key_dim, num_attention_heads, attention_ratio, resolution_in),
                    self.config.drop_path_rate,
                )
            )
            if mlp_ratio > 0:
                hidden_dim = hidden_sizes * mlp_ratio
                self.layers.append(
                    LevitResidualLayer(LevitMLPLayer(hidden_sizes, hidden_dim), self.config.drop_path_rate)
                )

        if down_ops[0] == "Subsample":
            self.resolution_out = (self.resolution_in - 1) // down_ops[5] + 1
            self.layers.append(
                LevitAttentionSubsample(
                    *self.config.hidden_sizes[idx : idx + 2],
                    key_dim=down_ops[1],
                    num_attention_heads=down_ops[2],
                    attention_ratio=down_ops[3],
                    stride=down_ops[5],
                    resolution_in=resolution_in,
                    resolution_out=self.resolution_out,
                )
            )
            self.resolution_in = self.resolution_out
            if down_ops[4] > 0:
                hidden_dim = self.config.hidden_sizes[idx + 1] * down_ops[4]
                self.layers.append(
                    LevitResidualLayer(
                        LevitMLPLayer(self.config.hidden_sizes[idx + 1], hidden_dim), self.config.drop_path_rate
                    )
                )

        self.layers = nn.ModuleList(self.layers)

    def get_resolution(self):
        return self.resolution_in

    def forward(self, hidden_state):
        for layer in self.layers:
            hidden_state = layer(hidden_state)
        return hidden_state

class LevitEncoder(nn.Module):
    """
    LeViT Encoder consisting of multiple `LevitStage` stages.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        resolution = self.config.image_size // self.config.patch_size
        self.stages = []
        self.config.down_ops.append([""])

        for stage_idx in range(len(config.depths)):
            stage = LevitStage(
                config,
                stage_idx,
                config.hidden_sizes[stage_idx],
                config.key_dim[stage_idx],
                config.depths[stage_idx],
                config.num_attention_heads[stage_idx],
                config.attention_ratio[stage_idx],
                config.mlp_ratio[stage_idx],
                config.down_ops[stage_idx],
                resolution,
            )
            resolution = stage.get_resolution()
            self.stages.append(stage)

        self.stages = nn.ModuleList(self.stages)

    def forward(self, hidden_state, output_hidden_states=False, return_dict=True):
        all_hidden_states = () if output_hidden_states else None

        for stage in self.stages:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_state,)
            hidden_state = stage(hidden_state)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_state,)
        if not return_dict:
            return tuple(v for v in [hidden_state, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(last_hidden_state=hidden_state, hidden_states=all_hidden_states)
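
# Note on the encoder construction above: `config.down_ops` holds one ("Subsample", ...) entry per stage
# transition, and the appended [""] sentinel gives the final stage no downsampling step. The token-grid
# resolution returned by each stage via `get_resolution` is threaded into the next stage, so the relative
# position bias tables are built for the correct grid size at every depth.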

class LevitClassificationLayer(nn.Module):
    """
    LeViT Classification Layer
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.batch_norm = nn.BatchNorm1d(input_dim)
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, hidden_state):
        hidden_state = self.batch_norm(hidden_state)
        logits = self.linear(hidden_state)
        return logits


class LevitPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = LevitConfig
    base_model_prefix = "levit"
    main_input_name = "pixel_values"
    _no_split_modules = ["LevitResidualLayer"]

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

LEVIT_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`LevitConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

LEVIT_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`LevitImageProcessor.__call__`] for details.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    "The bare Levit model outputting raw features without any specific head on top.",
    LEVIT_START_DOCSTRING,
)
class LevitModel(LevitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.patch_embeddings = LevitPatchEmbeddings(config)
        self.encoder = LevitEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPoolingAndNoAttention,
        config_class=_CONFIG_FOR_DOC,
        modality="vision",
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPoolingAndNoAttention]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embeddings = self.patch_embeddings(pixel_values)
        encoder_outputs = self.encoder(
            embeddings,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs[0]

        # global average pooling, (batch_size, seq_length, hidden_sizes) -> (batch_size, hidden_sizes)
        pooled_output = last_hidden_state.mean(dim=1)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndNoAttention(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
        )
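
# Minimal usage sketch for LevitModel (illustrative only; the facebook/levit-128S checkpoint and AutoImageProcessor
# come from the docstring constants above, and `image` is a hypothetical PIL image):
#
#     >>> from transformers import AutoImageProcessor, LevitModel
#     >>> processor = AutoImageProcessor.from_pretrained("facebook/levit-128S")
#     >>> model = LevitModel.from_pretrained("facebook/levit-128S")
#     >>> inputs = processor(images=image, return_tensors="pt")
#     >>> outputs = model(**inputs)
#     >>> outputs.last_hidden_state.shape  # matches _EXPECTED_OUTPUT_SHAPE above
#     torch.Size([1, 16, 384])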

@add_start_docstrings(
    """
    Levit Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for
    ImageNet.
    """,
    LEVIT_START_DOCSTRING,
)
class LevitForImageClassification(LevitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.levit = LevitModel(config)

        # Classifier head
        self.classifier = (
            LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
            if config.num_labels > 0
            else torch.nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=ImageClassifierOutputWithNoAttention,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        sequence_output = outputs[0]
        sequence_output = sequence_output.mean(1)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutputWithNoAttention(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )
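
# Minimal usage sketch for LevitForImageClassification (illustrative only; `image` is a hypothetical PIL image and
# the checkpoint is the _IMAGE_CLASS_CHECKPOINT defined above):
#
#     >>> from transformers import AutoImageProcessor, LevitForImageClassification
#     >>> processor = AutoImageProcessor.from_pretrained("facebook/levit-128S")
#     >>> model = LevitForImageClassification.from_pretrained("facebook/levit-128S")
#     >>> inputs = processor(images=image, return_tensors="pt")
#     >>> logits = model(**inputs).logits
#     >>> print(model.config.id2label[logits.argmax(-1).item()])  # e.g. "tabby, tabby cat"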

@add_start_docstrings(
    """
    LeViT Model transformer with image classification heads on top (a linear layer on top of the final hidden state
    and a linear layer on top of the final hidden state of the distillation token) e.g. for ImageNet.

    .. warning::
           This model supports inference-only. Fine-tuning with distillation (i.e. with a teacher) is not yet
           supported.
    """,
    LEVIT_START_DOCSTRING,
)
class LevitForImageClassificationWithTeacher(LevitPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.levit = LevitModel(config)

        # Classifier head
        self.classifier = (
            LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
            if config.num_labels > 0
            else torch.nn.Identity()
        )
        self.classifier_distill = (
            LevitClassificationLayer(config.hidden_sizes[-1], config.num_labels)
            if config.num_labels > 0
            else torch.nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(LEVIT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_IMAGE_CLASS_CHECKPOINT,
        output_type=LevitForImageClassificationWithTeacherOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
    )
    def forward(
        self,
        pixel_values: torch.FloatTensor = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LevitForImageClassificationWithTeacherOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.levit(pixel_values, output_hidden_states=output_hidden_states, return_dict=return_dict)

        sequence_output = outputs[0]
        sequence_output = sequence_output.mean(1)
        cls_logits, distill_logits = self.classifier(sequence_output), self.classifier_distill(sequence_output)
        logits = (cls_logits + distill_logits) / 2

        if not return_dict:
            output = (logits, cls_logits, distill_logits) + outputs[2:]
            return output

        return LevitForImageClassificationWithTeacherOutput(
            logits=logits,
            cls_logits=cls_logits,
            distillation_logits=distill_logits,
            hidden_states=outputs.hidden_states,
        )
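
# During inference the final prediction of LevitForImageClassificationWithTeacher is simply the average of the two
# heads, logits = (cls_logits + distillation_logits) / 2, as returned in
# LevitForImageClassificationWithTeacherOutput above.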