# coding=utf-8
# Copyright 2022 Multimedia Computing Group, Nanjing University and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch VideoMAE (masked autoencoder) model."""

import collections.abc
import math
from copy import deepcopy
from dataclasses import dataclass
from typing import Optional, Set, Tuple, Union

import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, ImageClassifierOutput
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ...utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .configuration_videomae import VideoMAEConfig


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "VideoMAEConfig"
_CHECKPOINT_FOR_DOC = "MCG-NJU/videomae-base"


@dataclass
class VideoMAEDecoderOutput(ModelOutput):
    """
    Class for VideoMAEDecoder's outputs, with potential hidden states and attentions.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class VideoMAEForPreTrainingOutput(ModelOutput):
    """
    Class for VideoMAEForPreTraining's outputs, with potential hidden states and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`):
            Pixel reconstruction loss.
        logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def get_sinusoid_encoding_table(n_position, d_hid):
    """Sinusoid position encoding table"""

    # TODO: make it with torch instead of numpy
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)
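
# The TODO above asks for a torch-only version of the table. A minimal sketch of an equivalent
# (hypothetical helper, not used anywhere in this module) built directly with torch ops; it returns
# the same (1, n_position, d_hid) tensor up to floating-point precision.
def _torch_sinusoid_encoding_table_sketch(n_position: int, d_hid: int) -> torch.Tensor:
    position = torch.arange(n_position, dtype=torch.float32).unsqueeze(1)  # (n_position, 1)
    pair_index = torch.arange(d_hid) // 2  # 0, 0, 1, 1, 2, 2, ...
    angles = position / torch.pow(10000.0, 2 * pair_index.float() / d_hid)  # (n_position, d_hid)
    table = torch.zeros(n_position, d_hid)
    table[:, 0::2] = torch.sin(angles[:, 0::2])  # dim 2i
    table[:, 1::2] = torch.cos(angles[:, 1::2])  # dim 2i+1
    return table.unsqueeze(0)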


class VideoMAEEmbeddings(nn.Module):
    """
    Construct the patch and position embeddings.
    """

    def __init__(self, config):
        super().__init__()

        self.patch_embeddings = VideoMAEPatchEmbeddings(config)
        self.num_patches = self.patch_embeddings.num_patches
        # fixed sin-cos embedding
        self.position_embeddings = get_sinusoid_encoding_table(self.num_patches, config.hidden_size)
        self.config = config

    def forward(self, pixel_values, bool_masked_pos):
        # create patch embeddings
        embeddings = self.patch_embeddings(pixel_values)

        # add position embeddings
        embeddings = embeddings + self.position_embeddings.type_as(embeddings).to(embeddings.device).clone().detach()

        # only keep visible patches
        # ~bool_masked_pos means visible
        if bool_masked_pos is not None:
            batch_size, _, num_channels = embeddings.shape
            embeddings = embeddings[~bool_masked_pos]
            embeddings = embeddings.reshape(batch_size, -1, num_channels)

        return embeddings
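
# A minimal, self-contained sketch of the masking step above (illustrative only, not used by the
# model): boolean indexing with `~bool_masked_pos` flattens the batch dimension, which is why every
# video in a batch must mask the same number of patches before the reshape can restore the batch.
def _visible_patch_selection_sketch() -> torch.Size:
    batch_size, seq_len, hidden_size = 2, 8, 4
    embeddings = torch.randn(batch_size, seq_len, hidden_size)
    bool_masked_pos = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    bool_masked_pos[:, ::2] = True  # mask every other patch (same count per video)
    visible = embeddings[~bool_masked_pos].reshape(batch_size, -1, hidden_size)
    return visible.shape  # torch.Size([2, 4, 4])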


class VideoMAEPatchEmbeddings(nn.Module):
    """
    Video to Patch Embedding. This module turns a batch of videos of shape (batch_size, num_frames, num_channels,
    height, width) into a tensor of shape (batch_size, seq_len, hidden_size) to be consumed by a Transformer encoder.

    The seq_len (the number of patches) equals (number of frames // tubelet_size) * (height // patch_size) * (width //
    patch_size).
    """

    def __init__(self, config):
        super().__init__()

        image_size = config.image_size
        patch_size = config.patch_size
        num_channels = config.num_channels
        hidden_size = config.hidden_size
        num_frames = config.num_frames
        tubelet_size = config.tubelet_size

        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
        self.image_size = image_size
        self.patch_size = patch_size
        self.tubelet_size = int(tubelet_size)
        num_patches = (
            (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) * (num_frames // self.tubelet_size)
        )
        self.num_channels = num_channels
        self.num_patches = num_patches
        self.projection = nn.Conv3d(
            in_channels=num_channels,
            out_channels=hidden_size,
            kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
            stride=(self.tubelet_size, patch_size[0], patch_size[1]),
        )

    def forward(self, pixel_values):
        batch_size, num_frames, num_channels, height, width = pixel_values.shape
        if num_channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values matches the one set in the configuration."
            )
        if height != self.image_size[0] or width != self.image_size[1]:
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
            )
        # permute to (batch_size, num_channels, num_frames, height, width)
        pixel_values = pixel_values.permute(0, 2, 1, 3, 4)
        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        return embeddings
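
# Worked example of the seq_len arithmetic in the docstring above (a sketch assuming the default
# base configuration: 16 frames, tubelet_size=2, image_size=224, patch_size=16; hypothetical helper,
# not used by the model).
def _num_patches_sketch(num_frames: int = 16, tubelet_size: int = 2, image_size: int = 224, patch_size: int = 16) -> int:
    # (224 // 16) * (224 // 16) * (16 // 2) = 14 * 14 * 8 = 1568, the sequence length quoted in the
    # usage examples further below.
    return (image_size // patch_size) * (image_size // patch_size) * (num_frames // tubelet_size)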


class VideoMAESelfAttention(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
                f"heads {config.num_attention_heads}."
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=False)

        if config.qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(self.all_head_size))
            self.v_bias = nn.Parameter(torch.zeros(self.all_head_size))
        else:
            self.q_bias = None
            self.v_bias = None

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
        keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
        values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
        queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)

        key_layer = self.transpose_for_scores(keys)
        value_layer = self.transpose_for_scores(values)
        query_layer = self.transpose_for_scores(queries)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
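
# Note on the zero key bias above (a brief derivation, not from the original source): with a shared
# key bias b_k, the logit for query q_i and key k_j becomes q_i·k_j + q_i·b_k, and the extra term
# q_i·b_k is the same for every key j. Softmax over the key axis is invariant to adding a per-row
# constant, so a learnable key bias would be redundant; when `config.qkv_bias` is True only query
# and value biases are learned, and the key bias is fixed to zeros.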


class VideoMAESdpaSelfAttention(VideoMAESelfAttention):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__(config)
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        k_bias = torch.zeros_like(self.v_bias, requires_grad=False) if self.q_bias is not None else None
        keys = nn.functional.linear(input=hidden_states, weight=self.key.weight, bias=k_bias)
        values = nn.functional.linear(input=hidden_states, weight=self.value.weight, bias=self.v_bias)
        queries = nn.functional.linear(input=hidden_states, weight=self.query.weight, bias=self.q_bias)

        key_layer = self.transpose_for_scores(keys)
        value_layer = self.transpose_for_scores(values)
        query_layer = self.transpose_for_scores(queries)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            self.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=None,
        )

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, None


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->VideoMAE
class VideoMAESelfOutput(nn.Module):
    """
    The residual connection is defined in VideoMAELayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    """

    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTAttention with ViT->VideoMAE
class VideoMAEAttention(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.attention = VideoMAESelfAttention(config)
        self.output = VideoMAESelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads: Set[int]) -> None:
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_outputs = self.attention(hidden_states, head_mask, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)

        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->VideoMAE
class VideoMAESdpaAttention(VideoMAEAttention):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__(config)
        self.attention = VideoMAESdpaSelfAttention(config)


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate ViT->VideoMAE
class VideoMAEIntermediate(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)

        return hidden_states


# Copied from transformers.models.vit.modeling_vit.ViTOutput ViT->VideoMAE
class VideoMAEOutput(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)

        hidden_states = hidden_states + input_tensor

        return hidden_states


VIDEOMAE_ATTENTION_CLASSES = {"eager": VideoMAEAttention, "sdpa": VideoMAESdpaAttention}


# Copied from transformers.models.vit.modeling_vit.ViTLayer with ViT->VideoMAE,VIT->VIDEOMAE
class VideoMAELayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = VIDEOMAE_ATTENTION_CLASSES[config._attn_implementation](config)
        self.intermediate = VideoMAEIntermediate(config)
        self.output = VideoMAEOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        self_attention_outputs = self.attention(
            self.layernorm_before(hidden_states),  # in VideoMAE, layernorm is applied before self-attention
            head_mask,
            output_attentions=output_attentions,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection
        hidden_states = attention_output + hidden_states

        # in VideoMAE, layernorm is also applied after self-attention
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)

        # second residual connection is done here
        layer_output = self.output(layer_output, hidden_states)

        outputs = (layer_output,) + outputs

        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTEncoder with ViT->VideoMAE
class VideoMAEEncoder(nn.Module):
    def __init__(self, config: VideoMAEConfig) -> None:
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([VideoMAELayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    layer_head_mask,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
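
# The encoder above only checkpoints its layers when `gradient_checkpointing` is set, which is
# normally done through the PreTrainedModel API rather than by toggling the flag here directly.
# Illustrative sketch (not part of the original module):
#
#     model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base")
#     model.gradient_checkpointing_enable()  # recompute activations in backward to save memory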


class VideoMAEPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = VideoMAEConfig
    base_model_prefix = "videomae"
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


VIDEOMAE_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`VideoMAEConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

VIDEOMAE_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`AutoImageProcessor`]. See
            [`VideoMAEImageProcessor.__call__`] for details.

        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare VideoMAE Model transformer outputting raw hidden-states without any specific head on top.",
    VIDEOMAE_START_DOCSTRING,
)
class VideoMAEModel(VideoMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = VideoMAEEmbeddings(config)
        self.encoder = VideoMAEEncoder(config)

        if config.use_mean_pooling:
            self.layernorm = None
        else:
            self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(VIDEOMAE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
            batch must have the same number of masked patches. If `None`, then all patches are considered. Sequence
            length is `(num_frames // tubelet_size) * (image_size // patch_size) ** 2`.

        Returns:

        Examples:

        ```python
        >>> import av
        >>> import numpy as np

        >>> from transformers import AutoImageProcessor, VideoMAEModel
        >>> from huggingface_hub import hf_hub_download

        >>> np.random.seed(0)


        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`List[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])


        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
        ...     '''
        ...     Sample a given number of frame indices from the video.
        ...     Args:
        ...         clip_len (`int`): Total number of frames to sample.
        ...         frame_sample_rate (`int`): Sample every n-th frame.
        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
        ...     Returns:
        ...         indices (`List[int]`): List of sampled frame indices
        ...     '''
        ...     converted_len = int(clip_len * frame_sample_rate)
        ...     end_idx = np.random.randint(converted_len, seg_len)
        ...     start_idx = end_idx - converted_len
        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        ...     return indices


        >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
        >>> file_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample 16 frames
        >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
        >>> video = read_video_pyav(container, indices)

        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
        >>> model = VideoMAEModel.from_pretrained("MCG-NJU/videomae-base")

        >>> # prepare video for the model
        >>> inputs = image_processor(list(video), return_tensors="pt")

        >>> # forward pass
        >>> outputs = model(**inputs)
        >>> last_hidden_states = outputs.last_hidden_state
        >>> list(last_hidden_states.shape)
        [1, 1568, 768]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(pixel_values, bool_masked_pos)

        encoder_outputs = self.encoder(
            embedding_output,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        if self.layernorm is not None:
            sequence_output = self.layernorm(sequence_output)

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
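
# A short sketch (not from the original docstring) of calling the bare model with a boolean mask so
# that only visible patches are encoded. With the default base config the full sequence length is
# 1568, so a 0.75 ratio drops 1176 patches and leaves 392 visible tokens in `last_hidden_state`.
def _masked_forward_sketch(model: "VideoMAEModel", pixel_values: torch.Tensor, mask_ratio: float = 0.75):
    seq_length = model.embeddings.num_patches
    num_masked = int(mask_ratio * seq_length)
    bool_masked_pos = torch.zeros(pixel_values.shape[0], seq_length, dtype=torch.bool, device=pixel_values.device)
    bool_masked_pos[:, :num_masked] = True  # mask the first `num_masked` patches (same count per video)
    return model(pixel_values, bool_masked_pos=bool_masked_pos)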


class VideoMAEDecoder(nn.Module):
    def __init__(self, config, num_patches):
        super().__init__()

        decoder_num_labels = config.num_channels * config.tubelet_size * config.patch_size**2

        decoder_config = deepcopy(config)
        decoder_config.hidden_size = config.decoder_hidden_size
        decoder_config.num_hidden_layers = config.decoder_num_hidden_layers
        decoder_config.num_attention_heads = config.decoder_num_attention_heads
        decoder_config.intermediate_size = config.decoder_intermediate_size
        self.decoder_layers = nn.ModuleList(
            [VideoMAELayer(decoder_config) for _ in range(config.decoder_num_hidden_layers)]
        )

        self.norm = nn.LayerNorm(config.decoder_hidden_size)
        self.head = (
            nn.Linear(config.decoder_hidden_size, decoder_num_labels) if decoder_num_labels > 0 else nn.Identity()
        )

        self.gradient_checkpointing = False
        self.config = config

    def forward(
        self,
        hidden_states,
        return_token_num,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        # apply Transformer layers (blocks)
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        for i, layer_module in enumerate(self.decoder_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    layer_module.__call__,
                    hidden_states,
                    None,
                    output_attentions,
                )
            else:
                layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if return_token_num > 0:
            hidden_states = hidden_states[:, -return_token_num:]

        # predictor projection
        hidden_states = self.norm(hidden_states)
        logits = self.head(hidden_states)
        if not return_dict:
            return tuple(v for v in [logits, all_hidden_states, all_self_attentions] if v is not None)
        return VideoMAEDecoderOutput(logits=logits, hidden_states=all_hidden_states, attentions=all_self_attentions)
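
# Worked example (assuming the default base config: num_channels=3, tubelet_size=2, patch_size=16):
# each decoder prediction covers one tubelet of pixels, so decoder_num_labels = 3 * 2 * 16**2 = 1536
# values per patch. Because VideoMAEForPreTraining appends the mask tokens after the visible tokens,
# passing return_token_num equal to the number of masked patches keeps only the reconstructions of
# the masked positions, giving logits of shape (batch_size, num_masked_patches, 1536).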


@add_start_docstrings(
    "The VideoMAE Model transformer with the decoder on top for self-supervised pre-training.",
    VIDEOMAE_START_DOCSTRING,
)
class VideoMAEForPreTraining(VideoMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.videomae = VideoMAEModel(config)

        self.encoder_to_decoder = nn.Linear(config.hidden_size, config.decoder_hidden_size, bias=False)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.decoder_hidden_size))
        self.position_embeddings = get_sinusoid_encoding_table(
            self.videomae.embeddings.num_patches, config.decoder_hidden_size
        )

        self.decoder = VideoMAEDecoder(config, num_patches=self.videomae.embeddings.num_patches)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VIDEOMAE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=VideoMAEForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        bool_masked_pos: torch.BoolTensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, VideoMAEForPreTrainingOutput]:
        r"""
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, sequence_length)`):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). Each video in the
            batch must have the same number of masked patches. Sequence length is `(num_frames // tubelet_size) *
            (image_size // patch_size) ** 2`.

        Returns:

        Examples:
        ```python
        >>> from transformers import AutoImageProcessor, VideoMAEForPreTraining
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 16
        >>> video = list(np.random.randint(0, 256, (num_frames, 3, 224, 224)))

        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
        >>> model = VideoMAEForPreTraining.from_pretrained("MCG-NJU/videomae-base")

        >>> pixel_values = image_processor(video, return_tensors="pt").pixel_values

        >>> num_patches_per_frame = (model.config.image_size // model.config.patch_size) ** 2
        >>> seq_length = (num_frames // model.config.tubelet_size) * num_patches_per_frame
        >>> bool_masked_pos = torch.randint(0, 2, (1, seq_length)).bool()

        >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
        >>> loss = outputs.loss
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.videomae(
            pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.encoder_to_decoder(
            sequence_output
        )  # [batch_size, num_visible_patches, decoder_hidden_size]
        batch_size, seq_len, num_channels = sequence_output.shape

        # we don't unshuffle the correct visible token order, but shuffle the position embeddings accordingly.
        if bool_masked_pos is None:
            raise ValueError("One must provide a boolean mask")
        expanded_position_embeddings = self.position_embeddings.expand(batch_size, -1, -1).type_as(pixel_values)
        expanded_position_embeddings = expanded_position_embeddings.to(pixel_values.device).clone().detach()
        pos_emb_visible = expanded_position_embeddings[~bool_masked_pos].reshape(batch_size, -1, num_channels)
        pos_emb_mask = expanded_position_embeddings[bool_masked_pos].reshape(batch_size, -1, num_channels)

        # [batch_size, num_patches, decoder_hidden_size]
        x_full = torch.cat([sequence_output + pos_emb_visible, self.mask_token + pos_emb_mask], dim=1)

        # [batch_size, num_masked_patches, num_channels * patch_size * patch_size]
        decoder_outputs = self.decoder(x_full, pos_emb_mask.shape[1])
        logits = decoder_outputs.logits

        loss = None
        with torch.no_grad():
            # calculate the labels to be predicted
            if self.config.num_channels != 3:
                # Can't unnormalize with default means/stds
                frames = pixel_values
            else:
                # first, unnormalize the frames
                device = pixel_values.device
                dtype = pixel_values.dtype
                mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(device=device, dtype=dtype)[None, None, :, None, None]
                std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(device=device, dtype=dtype)[None, None, :, None, None]
                frames = pixel_values * std + mean  # in [0, 1]

            batch_size, time, num_channels, height, width = frames.shape
            tubelet_size, patch_size = self.config.tubelet_size, self.config.patch_size
            if self.config.norm_pix_loss:
                # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
                frames = frames.view(
                    batch_size,
                    time // tubelet_size,
                    tubelet_size,
                    num_channels,
                    height // patch_size,
                    patch_size,
                    width // patch_size,
                    patch_size,
                )
                # step 2: move dimensions to concatenate:
                frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
                # step 3: concatenate:
                frames = frames.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size,
                    num_channels,
                )
                # step 4: normalize. The authors find that the mean is about 0.48 and standard deviation is about 0.08.
                frames_norm = (frames - frames.mean(dim=-2, keepdim=True)) / (
                    frames.var(dim=-2, unbiased=True, keepdim=True).sqrt() + 1e-6
                )
                # step 5: reshape to (batch_size, T//ts * H//ps * W//ps, ts * ps * ps * C)
                videos_patch = frames_norm.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size * num_channels,
                )
            else:
                if self.config.num_channels != 3:
                    raise ValueError(
                        "Can't unnormalize non-RGB images. Consider setting config.norm_pix_loss to False."
                    )
                # step 1: split up dimensions (time by tubelet_size, height by patch_size, width by patch_size)
                frames = frames.view(
                    batch_size,
                    time // tubelet_size,
                    tubelet_size,
                    num_channels,
                    height // patch_size,
                    patch_size,
                    width // patch_size,
                    patch_size,
                )
                # step 2: move dimensions to concatenate: (batch_size, T//ts, H//ps, W//ps, ts, ps, ps, C)
                frames = frames.permute(0, 1, 4, 6, 2, 5, 7, 3).contiguous()
                # step 3: concatenate
                videos_patch = frames.view(
                    batch_size,
                    time // tubelet_size * height // patch_size * width // patch_size,
                    tubelet_size * patch_size * patch_size * num_channels,
                )

            batch_size, _, num_channels = videos_patch.shape
            labels = videos_patch[bool_masked_pos].reshape(batch_size, -1, num_channels)

        loss_fct = MSELoss()
        loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return VideoMAEForPreTrainingOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
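
# The docstring example above draws a Bernoulli mask, which only works because the batch size is 1;
# with larger batches, every video must mask the same number of patches. A minimal sketch (not part
# of the original module) of a fixed-ratio random mask that satisfies this constraint; VideoMAE
# pre-training typically uses a high ratio such as 0.9.
def _random_fixed_ratio_mask_sketch(batch_size: int, seq_length: int, mask_ratio: float = 0.9) -> torch.Tensor:
    num_masked = int(mask_ratio * seq_length)
    noise = torch.rand(batch_size, seq_length)  # random scores per patch
    ids = noise.argsort(dim=1)  # a random permutation of patch indices per video
    bool_masked_pos = torch.zeros(batch_size, seq_length, dtype=torch.bool)
    bool_masked_pos.scatter_(1, ids[:, :num_masked], True)  # mask exactly `num_masked` patches per video
    return bool_masked_pos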


@add_start_docstrings(
    """VideoMAE Model transformer with a video classification head on top (a linear layer on top of the average pooled hidden
    states of all tokens) e.g. for Kinetics-400.""",
    VIDEOMAE_START_DOCSTRING,
)
class VideoMAEForVideoClassification(VideoMAEPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.videomae = VideoMAEModel(config)

        # Classifier head
        self.fc_norm = nn.LayerNorm(config.hidden_size) if config.use_mean_pooling else None
        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(VIDEOMAE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:

        Examples:

        ```python
        >>> import av
        >>> import torch
        >>> import numpy as np

        >>> from transformers import AutoImageProcessor, VideoMAEForVideoClassification
        >>> from huggingface_hub import hf_hub_download

        >>> np.random.seed(0)


        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`List[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])


        >>> def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
        ...     '''
        ...     Sample a given number of frame indices from the video.
        ...     Args:
        ...         clip_len (`int`): Total number of frames to sample.
        ...         frame_sample_rate (`int`): Sample every n-th frame.
        ...         seg_len (`int`): Maximum allowed index of sample's last frame.
        ...     Returns:
        ...         indices (`List[int]`): List of sampled frame indices
        ...     '''
        ...     converted_len = int(clip_len * frame_sample_rate)
        ...     end_idx = np.random.randint(converted_len, seg_len)
        ...     start_idx = end_idx - converted_len
        ...     indices = np.linspace(start_idx, end_idx, num=clip_len)
        ...     indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
        ...     return indices


        >>> # video clip consists of 300 frames (10 seconds at 30 FPS)
        >>> file_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample 16 frames
        >>> indices = sample_frame_indices(clip_len=16, frame_sample_rate=1, seg_len=container.streams.video[0].frames)
        >>> video = read_video_pyav(container, indices)

        >>> image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")
        >>> model = VideoMAEForVideoClassification.from_pretrained("MCG-NJU/videomae-base-finetuned-kinetics")

        >>> inputs = image_processor(list(video), return_tensors="pt")

        >>> with torch.no_grad():
        ...     outputs = model(**inputs)
        ...     logits = outputs.logits

        >>> # model predicts one of the 400 Kinetics-400 classes
        >>> predicted_label = logits.argmax(-1).item()
        >>> print(model.config.id2label[predicted_label])
        eating spaghetti
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.videomae(
            pixel_values,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        if self.fc_norm is not None:
            sequence_output = self.fc_norm(sequence_output.mean(1))
        else:
            sequence_output = sequence_output[:, 0]

        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )