modeling_data2vec_audio.py 77 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765
  1. # coding=utf-8
  2. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Data2VecAudio model."""
  16. import math
  17. import warnings
  18. from typing import Optional, Tuple, Union
  19. import numpy as np
  20. import torch
  21. import torch.utils.checkpoint
  22. from torch import nn
  23. from torch.nn import CrossEntropyLoss
  24. from ...activations import ACT2FN
  25. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  26. from ...integrations.fsdp import is_fsdp_managed_module
  27. from ...modeling_outputs import (
  28. BaseModelOutput,
  29. CausalLMOutput,
  30. SequenceClassifierOutput,
  31. TokenClassifierOutput,
  32. Wav2Vec2BaseModelOutput,
  33. XVectorOutput,
  34. )
  35. from ...modeling_utils import PreTrainedModel
  36. from ...utils import (
  37. add_code_sample_docstrings,
  38. add_start_docstrings,
  39. add_start_docstrings_to_model_forward,
  40. is_flash_attn_2_available,
  41. is_flash_attn_greater_or_equal_2_10,
  42. is_peft_available,
  43. logging,
  44. )
  45. from .configuration_data2vec_audio import Data2VecAudioConfig
  46. if is_flash_attn_2_available():
  47. from ...modeling_flash_attention_utils import _flash_attention_forward
  48. logger = logging.get_logger(__name__)
  49. _HIDDEN_STATES_START_POSITION = 2
  50. # General docstring
  51. _CONFIG_FOR_DOC = "Data2VecAudioConfig"
  52. # Base docstring
  53. _CHECKPOINT_FOR_DOC = "facebook/data2vec-audio-base-960h"
  54. _EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
  55. # CTC docstring
  56. _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
  57. _CTC_EXPECTED_LOSS = 66.95
  58. # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
  59. def _compute_mask_indices(
  60. shape: Tuple[int, int],
  61. mask_prob: float,
  62. mask_length: int,
  63. attention_mask: Optional[torch.LongTensor] = None,
  64. min_masks: int = 0,
  65. ) -> np.ndarray:
  66. """
  67. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  68. ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  69. CPU as part of the preprocessing during training.
  70. Args:
  71. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  72. the first element is the batch size and the second element is the length of the axis to span.
  73. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  74. independently generated mask spans of length `mask_length` is computed by
  75. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  76. actual percentage will be smaller.
  77. mask_length: size of the mask
  78. min_masks: minimum number of masked spans
  79. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  80. each batch dimension.
  81. """
  82. batch_size, sequence_length = shape
  83. if mask_length < 1:
  84. raise ValueError("`mask_length` has to be bigger than 0.")
  85. if mask_length > sequence_length:
  86. raise ValueError(
  87. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  88. f" and `sequence_length`: {sequence_length}`"
  89. )
  90. # epsilon is used for probabilistic rounding
  91. epsilon = np.random.rand(1).item()
  92. def compute_num_masked_span(input_length):
  93. """Given input length, compute how many spans should be masked"""
  94. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  95. num_masked_span = max(num_masked_span, min_masks)
  96. # make sure num masked span <= sequence_length
  97. if num_masked_span * mask_length > sequence_length:
  98. num_masked_span = sequence_length // mask_length
  99. # make sure num_masked span is also <= input_length - (mask_length - 1)
  100. if input_length - (mask_length - 1) < num_masked_span:
  101. num_masked_span = max(input_length - (mask_length - 1), 0)
  102. return num_masked_span
  103. # compute number of masked spans in batch
  104. input_lengths = (
  105. attention_mask.sum(-1).detach().tolist()
  106. if attention_mask is not None
  107. else [sequence_length for _ in range(batch_size)]
  108. )
  109. # SpecAugment mask to fill
  110. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  111. spec_aug_mask_idxs = []
  112. max_num_masked_span = compute_num_masked_span(sequence_length)
  113. if max_num_masked_span == 0:
  114. return spec_aug_mask
  115. for input_length in input_lengths:
  116. # compute num of masked spans for this input
  117. num_masked_span = compute_num_masked_span(input_length)
  118. # get random indices to mask
  119. spec_aug_mask_idx = np.random.choice(
  120. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  121. )
  122. # pick first sampled index that will serve as a dummy index to pad vector
  123. # to ensure same dimension for all batches due to probabilistic rounding
  124. # Picking first sample just pads those vectors twice.
  125. if len(spec_aug_mask_idx) == 0:
  126. # this case can only happen if `input_length` is strictly smaller then
  127. # `sequence_length` in which case the last token has to be a padding
  128. # token which we can use as a dummy mask id
  129. dummy_mask_idx = sequence_length - 1
  130. else:
  131. dummy_mask_idx = spec_aug_mask_idx[0]
  132. spec_aug_mask_idx = np.concatenate(
  133. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  134. )
  135. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  136. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  137. # expand masked indices to masked spans
  138. spec_aug_mask_idxs = np.broadcast_to(
  139. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  140. )
  141. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  142. # add offset to the starting indexes so that indexes now create a span
  143. offsets = np.arange(mask_length)[None, None, :]
  144. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  145. batch_size, max_num_masked_span * mask_length
  146. )
  147. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  148. # ensure that we cannot have indices larger than sequence_length
  149. if spec_aug_mask_idxs.max() > sequence_length - 1:
  150. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  151. # scatter indices to mask
  152. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  153. return spec_aug_mask
  154. class Data2VecAudioConvLayer(nn.Module):
  155. def __init__(self, config, layer_id=0):
  156. super().__init__()
  157. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  158. self.out_conv_dim = config.conv_dim[layer_id]
  159. self.conv = nn.Conv1d(
  160. self.in_conv_dim,
  161. self.out_conv_dim,
  162. kernel_size=config.conv_kernel[layer_id],
  163. stride=config.conv_stride[layer_id],
  164. bias=config.conv_bias,
  165. )
  166. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  167. self.activation = ACT2FN[config.feat_extract_activation]
  168. def forward(self, hidden_states):
  169. hidden_states = self.conv(hidden_states)
  170. hidden_states = hidden_states.transpose(-2, -1)
  171. hidden_states = self.layer_norm(hidden_states)
  172. hidden_states = hidden_states.transpose(-2, -1)
  173. hidden_states = self.activation(hidden_states)
  174. return hidden_states
  175. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Data2VecAudio
  176. class Data2VecAudioPadLayer(nn.Module):
  177. def __init__(self, num_conv_pos_embeddings):
  178. super().__init__()
  179. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  180. def forward(self, hidden_states):
  181. if self.num_pad_remove > 0:
  182. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  183. return hidden_states
  184. class Data2VecAudioPositionalConvLayer(nn.Module):
  185. def __init__(self, config):
  186. super().__init__()
  187. self.conv = nn.Conv1d(
  188. config.hidden_size,
  189. config.hidden_size,
  190. kernel_size=config.conv_pos_kernel_size,
  191. padding=config.conv_pos_kernel_size // 2,
  192. groups=config.num_conv_pos_embedding_groups,
  193. )
  194. self.padding = Data2VecAudioPadLayer(config.conv_pos_kernel_size)
  195. self.activation = ACT2FN[config.feat_extract_activation]
  196. # no learnable parameters
  197. self.layer_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False)
  198. def forward(self, hidden_states):
  199. hidden_states = self.conv(hidden_states)
  200. hidden_states = self.padding(hidden_states)
  201. hidden_states = hidden_states.transpose(1, 2)
  202. hidden_states = self.layer_norm(hidden_states)
  203. hidden_states = hidden_states.transpose(1, 2)
  204. hidden_states = self.activation(hidden_states)
  205. return hidden_states
  206. class Data2VecAudioPositionalConvEmbedding(nn.Module):
  207. def __init__(self, config):
  208. super().__init__()
  209. self.layers = nn.ModuleList(
  210. [Data2VecAudioPositionalConvLayer(config) for _ in range(config.num_conv_pos_embeddings)]
  211. )
  212. def forward(self, hidden_states):
  213. hidden_states = hidden_states.transpose(1, 2)
  214. for layer in self.layers:
  215. hidden_states = layer(hidden_states)
  216. hidden_states = hidden_states.transpose(1, 2)
  217. return hidden_states
  218. class Data2VecAudioFeatureEncoder(nn.Module):
  219. """Construct the features from raw audio waveform"""
  220. def __init__(self, config):
  221. super().__init__()
  222. self.conv_layers = nn.ModuleList(
  223. [Data2VecAudioConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
  224. )
  225. self.gradient_checkpointing = False
  226. self._requires_grad = True
  227. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder._freeze_parameters
  228. def _freeze_parameters(self):
  229. for param in self.parameters():
  230. param.requires_grad = False
  231. self._requires_grad = False
  232. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder.forward
  233. def forward(self, input_values):
  234. hidden_states = input_values[:, None]
  235. # make sure hidden_states require grad for gradient_checkpointing
  236. if self._requires_grad and self.training:
  237. hidden_states.requires_grad = True
  238. for conv_layer in self.conv_layers:
  239. if self._requires_grad and self.gradient_checkpointing and self.training:
  240. hidden_states = self._gradient_checkpointing_func(
  241. conv_layer.__call__,
  242. hidden_states,
  243. )
  244. else:
  245. hidden_states = conv_layer(hidden_states)
  246. return hidden_states
  247. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Data2VecAudio
  248. class Data2VecAudioFeatureProjection(nn.Module):
  249. def __init__(self, config):
  250. super().__init__()
  251. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  252. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  253. self.dropout = nn.Dropout(config.feat_proj_dropout)
  254. def forward(self, hidden_states):
  255. # non-projected hidden states are needed for quantization
  256. norm_hidden_states = self.layer_norm(hidden_states)
  257. hidden_states = self.projection(norm_hidden_states)
  258. hidden_states = self.dropout(hidden_states)
  259. return hidden_states, norm_hidden_states
  260. # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Data2VecAudio
  261. class Data2VecAudioAttention(nn.Module):
  262. """Multi-headed attention from 'Attention Is All You Need' paper"""
  263. def __init__(
  264. self,
  265. embed_dim: int,
  266. num_heads: int,
  267. dropout: float = 0.0,
  268. is_decoder: bool = False,
  269. bias: bool = True,
  270. is_causal: bool = False,
  271. config: Optional[Data2VecAudioConfig] = None,
  272. ):
  273. super().__init__()
  274. self.embed_dim = embed_dim
  275. self.num_heads = num_heads
  276. self.dropout = dropout
  277. self.head_dim = embed_dim // num_heads
  278. self.config = config
  279. if (self.head_dim * num_heads) != self.embed_dim:
  280. raise ValueError(
  281. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  282. f" and `num_heads`: {num_heads})."
  283. )
  284. self.scaling = self.head_dim**-0.5
  285. self.is_decoder = is_decoder
  286. self.is_causal = is_causal
  287. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  288. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  289. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  290. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  291. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  292. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  293. def forward(
  294. self,
  295. hidden_states: torch.Tensor,
  296. key_value_states: Optional[torch.Tensor] = None,
  297. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  298. attention_mask: Optional[torch.Tensor] = None,
  299. layer_head_mask: Optional[torch.Tensor] = None,
  300. output_attentions: bool = False,
  301. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  302. """Input shape: Batch x Time x Channel"""
  303. # if key_value_states are provided this layer is used as a cross-attention layer
  304. # for the decoder
  305. is_cross_attention = key_value_states is not None
  306. bsz, tgt_len, _ = hidden_states.size()
  307. # get query proj
  308. query_states = self.q_proj(hidden_states) * self.scaling
  309. # get key, value proj
  310. # `past_key_value[0].shape[2] == key_value_states.shape[1]`
  311. # is checking that the `sequence_length` of the `past_key_value` is the same as
  312. # the provided `key_value_states` to support prefix tuning
  313. if (
  314. is_cross_attention
  315. and past_key_value is not None
  316. and past_key_value[0].shape[2] == key_value_states.shape[1]
  317. ):
  318. # reuse k,v, cross_attentions
  319. key_states = past_key_value[0]
  320. value_states = past_key_value[1]
  321. elif is_cross_attention:
  322. # cross_attentions
  323. key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
  324. value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
  325. elif past_key_value is not None:
  326. # reuse k, v, self_attention
  327. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  328. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  329. key_states = torch.cat([past_key_value[0], key_states], dim=2)
  330. value_states = torch.cat([past_key_value[1], value_states], dim=2)
  331. else:
  332. # self_attention
  333. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  334. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  335. if self.is_decoder:
  336. # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
  337. # Further calls to cross_attention layer can then reuse all cross-attention
  338. # key/value_states (first "if" case)
  339. # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
  340. # all previous decoder key/value_states. Further calls to uni-directional self-attention
  341. # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
  342. # if encoder bi-directional self-attention `past_key_value` is always `None`
  343. past_key_value = (key_states, value_states)
  344. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  345. query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
  346. key_states = key_states.reshape(*proj_shape)
  347. value_states = value_states.reshape(*proj_shape)
  348. src_len = key_states.size(1)
  349. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  350. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  351. raise ValueError(
  352. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  353. f" {attn_weights.size()}"
  354. )
  355. if attention_mask is not None:
  356. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  357. raise ValueError(
  358. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  359. )
  360. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  361. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  362. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  363. if layer_head_mask is not None:
  364. if layer_head_mask.size() != (self.num_heads,):
  365. raise ValueError(
  366. f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
  367. f" {layer_head_mask.size()}"
  368. )
  369. attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  370. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  371. if output_attentions:
  372. # this operation is a bit awkward, but it's required to
  373. # make sure that attn_weights keeps its gradient.
  374. # In order to do so, attn_weights have to be reshaped
  375. # twice and have to be reused in the following
  376. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  377. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  378. else:
  379. attn_weights_reshaped = None
  380. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  381. attn_output = torch.bmm(attn_probs, value_states)
  382. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  383. raise ValueError(
  384. f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
  385. f" {attn_output.size()}"
  386. )
  387. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  388. attn_output = attn_output.transpose(1, 2)
  389. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  390. # partitioned across GPUs when using tensor-parallelism.
  391. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  392. attn_output = self.out_proj(attn_output)
  393. return attn_output, attn_weights_reshaped, past_key_value
  394. # Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Data2VecAudio
  395. class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
  396. """
  397. Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays
  398. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  399. flash attention and deal with padding tokens in case the input contains any of them.
  400. """
  401. # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
  402. def __init__(self, *args, **kwargs):
  403. super().__init__(*args, **kwargs)
  404. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  405. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  406. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  407. self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
  408. def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  409. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
  410. def forward(
  411. self,
  412. hidden_states: torch.Tensor,
  413. key_value_states: Optional[torch.Tensor] = None,
  414. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  415. attention_mask: Optional[torch.Tensor] = None,
  416. layer_head_mask: Optional[torch.Tensor] = None,
  417. output_attentions: bool = False,
  418. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  419. # Data2VecAudioFlashAttention2 attention does not support output_attentions
  420. if output_attentions:
  421. raise ValueError("Data2VecAudioFlashAttention2 attention does not support output_attentions")
  422. # if key_value_states are provided this layer is used as a cross-attention layer
  423. # for the decoder
  424. is_cross_attention = key_value_states is not None
  425. bsz, q_len, _ = hidden_states.size()
  426. # get query proj
  427. query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
  428. # get key, value proj
  429. # `past_key_value[0].shape[2] == key_value_states.shape[1]`
  430. # is checking that the `sequence_length` of the `past_key_value` is the same as
  431. # the provided `key_value_states` to support prefix tuning
  432. if (
  433. is_cross_attention
  434. and past_key_value is not None
  435. and past_key_value[0].shape[2] == key_value_states.shape[1]
  436. ):
  437. # reuse k,v, cross_attentions
  438. key_states = past_key_value[0].transpose(1, 2)
  439. value_states = past_key_value[1].transpose(1, 2)
  440. elif is_cross_attention:
  441. # cross_attentions
  442. key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
  443. value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
  444. elif past_key_value is not None:
  445. # reuse k, v, self_attention
  446. key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
  447. value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
  448. key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
  449. value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
  450. else:
  451. # self_attention
  452. key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
  453. value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
  454. if self.is_decoder:
  455. # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
  456. # Further calls to cross_attention layer can then reuse all cross-attention
  457. # key/value_states (first "if" case)
  458. # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
  459. # all previous decoder key/value_states. Further calls to uni-directional self-attention
  460. # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
  461. # if encoder bi-directional self-attention `past_key_value` is always `None`
  462. past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
  463. kv_seq_len = key_states.shape[-2]
  464. if past_key_value is not None:
  465. kv_seq_len += past_key_value[0].shape[-2]
  466. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  467. # therefore the input hidden states gets silently casted in float32. Hence, we need
  468. # cast them back in the correct dtype just to be sure everything works as expected.
  469. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  470. # in fp32. (LlamaRMSNorm handles it correctly)
  471. input_dtype = query_states.dtype
  472. if input_dtype == torch.float32:
  473. if torch.is_autocast_enabled():
  474. target_dtype = torch.get_autocast_gpu_dtype()
  475. # Handle the case where the model is quantized
  476. elif hasattr(self.config, "_pre_quantization_dtype"):
  477. target_dtype = self.config._pre_quantization_dtype
  478. else:
  479. target_dtype = self.q_proj.weight.dtype
  480. logger.warning_once(
  481. f"The input hidden states seems to be silently casted in float32, this might be related to"
  482. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  483. f" {target_dtype}."
  484. )
  485. query_states = query_states.to(target_dtype)
  486. key_states = key_states.to(target_dtype)
  487. value_states = value_states.to(target_dtype)
  488. attn_output = _flash_attention_forward(
  489. query_states,
  490. key_states,
  491. value_states,
  492. attention_mask,
  493. q_len,
  494. dropout=self.dropout if self.training else 0.0,
  495. is_causal=self.is_causal,
  496. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  497. )
  498. attn_output = attn_output.reshape(bsz, q_len, -1)
  499. attn_output = self.out_proj(attn_output)
  500. if not output_attentions:
  501. attn_weights = None
  502. return attn_output, attn_weights, past_key_value
  503. class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
  504. # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Data2VecAudio
  505. def forward(
  506. self,
  507. hidden_states: torch.Tensor,
  508. key_value_states: Optional[torch.Tensor] = None,
  509. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  510. attention_mask: Optional[torch.Tensor] = None,
  511. layer_head_mask: Optional[torch.Tensor] = None,
  512. output_attentions: bool = False,
  513. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  514. """Input shape: Batch x Time x Channel"""
  515. if output_attentions or layer_head_mask is not None:
  516. # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
  517. logger.warning_once(
  518. "Data2VecAudioModel is using Data2VecAudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
  519. ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  520. )
  521. return super().forward(
  522. hidden_states,
  523. key_value_states=key_value_states,
  524. past_key_value=past_key_value,
  525. attention_mask=attention_mask,
  526. layer_head_mask=layer_head_mask,
  527. output_attentions=output_attentions,
  528. )
  529. # if key_value_states are provided this layer is used as a cross-attention layer
  530. # for the decoder
  531. is_cross_attention = key_value_states is not None
  532. bsz, tgt_len, _ = hidden_states.size()
  533. # get query proj
  534. query_states = self.q_proj(hidden_states)
  535. # get key, value proj
  536. # `past_key_value[0].shape[2] == key_value_states.shape[1]`
  537. # is checking that the `sequence_length` of the `past_key_value` is the same as
  538. # the provided `key_value_states` to support prefix tuning
  539. if (
  540. is_cross_attention
  541. and past_key_value is not None
  542. and past_key_value[0].shape[2] == key_value_states.shape[1]
  543. ):
  544. # reuse k,v, cross_attentions
  545. key_states = past_key_value[0]
  546. value_states = past_key_value[1]
  547. elif is_cross_attention:
  548. # cross_attentions
  549. key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
  550. value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
  551. elif past_key_value is not None:
  552. # reuse k, v, self_attention
  553. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  554. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  555. key_states = torch.cat([past_key_value[0], key_states], dim=2)
  556. value_states = torch.cat([past_key_value[1], value_states], dim=2)
  557. else:
  558. # self_attention
  559. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  560. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  561. if self.is_decoder:
  562. # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
  563. # Further calls to cross_attention layer can then reuse all cross-attention
  564. # key/value_states (first "if" case)
  565. # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
  566. # all previous decoder key/value_states. Further calls to uni-directional self-attention
  567. # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
  568. # if encoder bi-directional self-attention `past_key_value` is always `None`
  569. past_key_value = (key_states, value_states)
  570. query_states = self._shape(query_states, tgt_len, bsz)
  571. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  572. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  573. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
  574. is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
  575. # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
  576. # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
  577. attn_output = torch.nn.functional.scaled_dot_product_attention(
  578. query_states,
  579. key_states,
  580. value_states,
  581. attn_mask=attention_mask,
  582. dropout_p=self.dropout if self.training else 0.0,
  583. is_causal=is_causal,
  584. )
  585. if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
  586. raise ValueError(
  587. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  588. f" {attn_output.size()}"
  589. )
  590. attn_output = attn_output.transpose(1, 2)
  591. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  592. # partitioned across GPUs when using tensor-parallelism.
  593. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  594. attn_output = self.out_proj(attn_output)
  595. return attn_output, None, past_key_value
  596. DATA2VEC2AUDIO_ATTENTION_CLASSES = {
  597. "eager": Data2VecAudioAttention,
  598. "sdpa": Data2VecAudioSdpaAttention,
  599. "flash_attention_2": Data2VecAudioFlashAttention2,
  600. }
  601. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Data2VecAudio
  602. class Data2VecAudioFeedForward(nn.Module):
  603. def __init__(self, config):
  604. super().__init__()
  605. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  606. self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
  607. if isinstance(config.hidden_act, str):
  608. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  609. else:
  610. self.intermediate_act_fn = config.hidden_act
  611. self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
  612. self.output_dropout = nn.Dropout(config.hidden_dropout)
  613. def forward(self, hidden_states):
  614. hidden_states = self.intermediate_dense(hidden_states)
  615. hidden_states = self.intermediate_act_fn(hidden_states)
  616. hidden_states = self.intermediate_dropout(hidden_states)
  617. hidden_states = self.output_dense(hidden_states)
  618. hidden_states = self.output_dropout(hidden_states)
  619. return hidden_states
  620. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Data2VecAudio, WAV2VEC2->DATA2VEC2AUDIO
  621. class Data2VecAudioEncoderLayer(nn.Module):
  622. def __init__(self, config):
  623. super().__init__()
  624. self.attention = DATA2VEC2AUDIO_ATTENTION_CLASSES[config._attn_implementation](
  625. embed_dim=config.hidden_size,
  626. num_heads=config.num_attention_heads,
  627. dropout=config.attention_dropout,
  628. is_decoder=False,
  629. )
  630. self.dropout = nn.Dropout(config.hidden_dropout)
  631. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  632. self.feed_forward = Data2VecAudioFeedForward(config)
  633. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  634. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  635. attn_residual = hidden_states
  636. hidden_states, attn_weights, _ = self.attention(
  637. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  638. )
  639. hidden_states = self.dropout(hidden_states)
  640. hidden_states = attn_residual + hidden_states
  641. hidden_states = self.layer_norm(hidden_states)
  642. hidden_states = hidden_states + self.feed_forward(hidden_states)
  643. hidden_states = self.final_layer_norm(hidden_states)
  644. outputs = (hidden_states,)
  645. if output_attentions:
  646. outputs += (attn_weights,)
  647. return outputs
  648. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Data2VecAudio
  649. class Data2VecAudioEncoder(nn.Module):
  650. def __init__(self, config):
  651. super().__init__()
  652. self.config = config
  653. self.pos_conv_embed = Data2VecAudioPositionalConvEmbedding(config)
  654. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  655. self.dropout = nn.Dropout(config.hidden_dropout)
  656. self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  657. self.gradient_checkpointing = False
  658. self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
  659. def forward(
  660. self,
  661. hidden_states: torch.tensor,
  662. attention_mask: Optional[torch.Tensor] = None,
  663. output_attentions: bool = False,
  664. output_hidden_states: bool = False,
  665. return_dict: bool = True,
  666. ):
  667. all_hidden_states = () if output_hidden_states else None
  668. all_self_attentions = () if output_attentions else None
  669. if attention_mask is not None:
  670. # make sure padded tokens output 0
  671. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  672. hidden_states[~expand_attention_mask] = 0
  673. if self._use_flash_attention_2:
  674. # 2d mask is passed through the layers
  675. attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
  676. else:
  677. # extend attention_mask
  678. attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
  679. attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
  680. attention_mask = attention_mask.expand(
  681. attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
  682. )
  683. position_embeddings = self.pos_conv_embed(hidden_states)
  684. hidden_states = hidden_states + position_embeddings
  685. hidden_states = self.layer_norm(hidden_states)
  686. hidden_states = self.dropout(hidden_states)
  687. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  688. for layer in self.layers:
  689. if output_hidden_states:
  690. all_hidden_states = all_hidden_states + (hidden_states,)
  691. # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
  692. dropout_probability = torch.rand([])
  693. skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
  694. if not skip_the_layer or synced_gpus:
  695. # under fsdp or deepspeed zero3 all gpus must run in sync
  696. if self.gradient_checkpointing and self.training:
  697. layer_outputs = self._gradient_checkpointing_func(
  698. layer.__call__,
  699. hidden_states,
  700. attention_mask,
  701. output_attentions,
  702. )
  703. else:
  704. layer_outputs = layer(
  705. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  706. )
  707. hidden_states = layer_outputs[0]
  708. if skip_the_layer:
  709. layer_outputs = (None, None)
  710. if output_attentions:
  711. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  712. if output_hidden_states:
  713. all_hidden_states = all_hidden_states + (hidden_states,)
  714. if not return_dict:
  715. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  716. return BaseModelOutput(
  717. last_hidden_state=hidden_states,
  718. hidden_states=all_hidden_states,
  719. attentions=all_self_attentions,
  720. )
  721. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Data2VecAudio
  722. class Data2VecAudioAdapter(nn.Module):
  723. def __init__(self, config):
  724. super().__init__()
  725. # feature dim might need to be down-projected
  726. if config.output_hidden_size != config.hidden_size:
  727. self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
  728. self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
  729. else:
  730. self.proj = self.proj_layer_norm = None
  731. self.layers = nn.ModuleList(Data2VecAudioAdapterLayer(config) for _ in range(config.num_adapter_layers))
  732. self.layerdrop = config.layerdrop
  733. def forward(self, hidden_states):
  734. # down project hidden_states if necessary
  735. if self.proj is not None and self.proj_layer_norm is not None:
  736. hidden_states = self.proj(hidden_states)
  737. hidden_states = self.proj_layer_norm(hidden_states)
  738. hidden_states = hidden_states.transpose(1, 2)
  739. for layer in self.layers:
  740. layerdrop_prob = np.random.random()
  741. if not self.training or (layerdrop_prob > self.layerdrop):
  742. hidden_states = layer(hidden_states)
  743. hidden_states = hidden_states.transpose(1, 2)
  744. return hidden_states
  745. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Data2VecAudio
  746. class Data2VecAudioAdapterLayer(nn.Module):
  747. def __init__(self, config):
  748. super().__init__()
  749. self.conv = nn.Conv1d(
  750. config.output_hidden_size,
  751. 2 * config.output_hidden_size,
  752. config.adapter_kernel_size,
  753. stride=config.adapter_stride,
  754. padding=1,
  755. )
  756. def forward(self, hidden_states):
  757. hidden_states = self.conv(hidden_states)
  758. hidden_states = nn.functional.glu(hidden_states, dim=1)
  759. return hidden_states
  760. class Data2VecAudioPreTrainedModel(PreTrainedModel):
  761. """
  762. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  763. models.
  764. """
  765. config_class = Data2VecAudioConfig
  766. base_model_prefix = "data2vec_audio"
  767. main_input_name = "input_values"
  768. supports_gradient_checkpointing = True
  769. _supports_flash_attn_2 = True
  770. _supports_sdpa = True
  771. def _init_weights(self, module):
  772. """Initialize the weights"""
  773. if isinstance(module, Data2VecAudioFeatureProjection):
  774. k = math.sqrt(1 / module.projection.in_features)
  775. nn.init.uniform_(module.projection.weight, a=-k, b=k)
  776. nn.init.uniform_(module.projection.bias, a=-k, b=k)
  777. elif isinstance(module, Data2VecAudioPositionalConvLayer):
  778. nn.init.constant_(module.conv.bias, 0)
  779. elif isinstance(module, nn.Linear):
  780. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  781. if module.bias is not None:
  782. module.bias.data.zero_()
  783. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  784. if module.bias is not None:
  785. module.bias.data.zero_()
  786. if module.weight is not None:
  787. module.weight.data.fill_(1.0)
  788. elif isinstance(module, nn.Conv1d):
  789. nn.init.kaiming_normal_(module.weight)
  790. if module.bias is not None:
  791. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  792. nn.init.uniform_(module.bias, a=-k, b=k)
  793. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel._get_feat_extract_output_lengths with
  794. def _get_feat_extract_output_lengths(
  795. self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
  796. ):
  797. """
  798. Computes the output length of the convolutional layers
  799. """
  800. add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
  801. def _conv_out_length(input_length, kernel_size, stride):
  802. # 1D convolutional layer output length formula taken
  803. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  804. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  805. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  806. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  807. if add_adapter:
  808. for _ in range(self.config.num_adapter_layers):
  809. input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
  810. return input_lengths
  811. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel._get_feature_vector_attention_mask
  812. def _get_feature_vector_attention_mask(
  813. self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
  814. ):
  815. # Effectively attention_mask.sum(-1), but not inplace to be able to run
  816. # on inference mode.
  817. non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
  818. output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
  819. output_lengths = output_lengths.to(torch.long)
  820. batch_size = attention_mask.shape[0]
  821. attention_mask = torch.zeros(
  822. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  823. )
  824. # these two operations makes sure that all values before the output lengths idxs are attended to
  825. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  826. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  827. return attention_mask
  828. DATA2VEC_AUDIO_START_DOCSTRING = r"""
  829. Data2VecAudio was proposed in [data2vec: A General Framework for Self-supervised Learning in Speech, Vision and
  830. Language](https://arxiv.org/pdf/2202.03555) by Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu and
  831. Michael Auli.
  832. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  833. library implements for all its model (such as downloading or saving etc.).
  834. This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
  835. it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
  836. behavior.
  837. Parameters:
  838. config ([`Data2VecAudioConfig`]): Model configuration class with all the parameters of the model.
  839. Initializing with a config file does not load the weights associated with the model, only the
  840. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  841. """
  842. DATA2VEC_AUDIO_INPUTS_DOCSTRING = r"""
  843. Args:
  844. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  845. Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
  846. into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
  847. soundfile*). To prepare the array into *input_values*, the [`AutoProcessor`] should be used for padding and
  848. conversion into a tensor of type *torch.FloatTensor*. See [`Wav2Vec2Processor.__call__`] for details.
  849. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  850. Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
  851. 1]`:
  852. - 1 for tokens that are **not masked**,
  853. - 0 for tokens that are **masked**.
  854. [What are attention masks?](../glossary#attention-mask)
  855. <Tip warning={true}>
  856. `attention_mask` should be passed if the corresponding processor has `config.return_attention_mask ==
  857. True`, which is the case for all pre-trained Data2Vec Audio models. Be aware that that even with
  858. `attention_mask`, zero-padded inputs will have slightly different outputs compared to non-padded inputs
  859. because there are more than one convolutional layer in the positional encodings. For a more detailed
  860. explanation, see [here](https://github.com/huggingface/transformers/issues/25621#issuecomment-1713759349).
  861. </Tip>
  862. output_attentions (`bool`, *optional*):
  863. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  864. tensors for more detail.
  865. output_hidden_states (`bool`, *optional*):
  866. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  867. more detail.
  868. return_dict (`bool`, *optional*):
  869. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  870. """
  871. @add_start_docstrings(
  872. "The bare Data2VecAudio Model transformer outputting raw hidden-states without any specific head on top.",
  873. DATA2VEC_AUDIO_START_DOCSTRING,
  874. )
  875. class Data2VecAudioModel(Data2VecAudioPreTrainedModel):
  876. def __init__(self, config: Data2VecAudioConfig):
  877. super().__init__(config)
  878. self.config = config
  879. self.feature_extractor = Data2VecAudioFeatureEncoder(config)
  880. self.feature_projection = Data2VecAudioFeatureProjection(config)
  881. # model only needs masking vector if mask prob is > 0.0
  882. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  883. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  884. self.encoder = Data2VecAudioEncoder(config)
  885. self.adapter = Data2VecAudioAdapter(config) if config.add_adapter else None
  886. # Initialize weights and apply final processing
  887. self.post_init()
  888. def freeze_feature_encoder(self):
  889. """
  890. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  891. not be updated during training.
  892. """
  893. self.feature_extractor._freeze_parameters()
  894. def _mask_hidden_states(
  895. self,
  896. hidden_states: torch.FloatTensor,
  897. mask_time_indices: Optional[torch.FloatTensor] = None,
  898. attention_mask: Optional[torch.LongTensor] = None,
  899. ):
  900. """
  901. Masks extracted features along time axis and/or along feature axis according to
  902. [SpecAugment](https://arxiv.org/abs/1904.08779).
  903. """
  904. # `config.apply_spec_augment` can set masking to False
  905. if not getattr(self.config, "apply_spec_augment", True):
  906. return hidden_states
  907. # generate indices & apply SpecAugment along time axis
  908. batch_size, sequence_length, hidden_size = hidden_states.size()
  909. if mask_time_indices is not None:
  910. # apply SpecAugment along time axis with given mask_time_indices
  911. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  912. elif self.config.mask_time_prob > 0 and self.training:
  913. mask_time_indices = _compute_mask_indices(
  914. (batch_size, sequence_length),
  915. mask_prob=self.config.mask_time_prob,
  916. mask_length=self.config.mask_time_length,
  917. attention_mask=attention_mask,
  918. min_masks=self.config.mask_time_min_masks,
  919. )
  920. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  921. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  922. if self.config.mask_feature_prob > 0 and self.training:
  923. # generate indices & apply SpecAugment along feature axis
  924. mask_feature_indices = _compute_mask_indices(
  925. (batch_size, hidden_size),
  926. mask_prob=self.config.mask_feature_prob,
  927. mask_length=self.config.mask_feature_length,
  928. min_masks=self.config.mask_feature_min_masks,
  929. )
  930. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  931. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  932. hidden_states[mask_feature_indices] = 0
  933. return hidden_states
  934. @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
  935. @add_code_sample_docstrings(
  936. checkpoint=_CHECKPOINT_FOR_DOC,
  937. output_type=Wav2Vec2BaseModelOutput,
  938. config_class=_CONFIG_FOR_DOC,
  939. modality="audio",
  940. expected_output=_EXPECTED_OUTPUT_SHAPE,
  941. )
  942. def forward(
  943. self,
  944. input_values: Optional[torch.Tensor],
  945. attention_mask: Optional[torch.Tensor] = None,
  946. mask_time_indices: Optional[torch.FloatTensor] = None,
  947. output_attentions: Optional[bool] = None,
  948. output_hidden_states: Optional[bool] = None,
  949. return_dict: Optional[bool] = None,
  950. ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
  951. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  952. output_hidden_states = (
  953. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  954. )
  955. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  956. extract_features = self.feature_extractor(input_values)
  957. extract_features = extract_features.transpose(1, 2)
  958. if attention_mask is not None:
  959. # compute reduced attention_mask corresponding to feature vectors
  960. attention_mask = self._get_feature_vector_attention_mask(
  961. extract_features.shape[1], attention_mask, add_adapter=False
  962. )
  963. hidden_states, extract_features = self.feature_projection(extract_features)
  964. hidden_states = self._mask_hidden_states(
  965. hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
  966. )
  967. encoder_outputs = self.encoder(
  968. hidden_states,
  969. attention_mask=attention_mask,
  970. output_attentions=output_attentions,
  971. output_hidden_states=output_hidden_states,
  972. return_dict=return_dict,
  973. )
  974. hidden_states = encoder_outputs[0]
  975. if self.adapter is not None:
  976. hidden_states = self.adapter(hidden_states)
  977. if not return_dict:
  978. return (hidden_states, extract_features) + encoder_outputs[1:]
  979. return Wav2Vec2BaseModelOutput(
  980. last_hidden_state=hidden_states,
  981. extract_features=extract_features,
  982. hidden_states=encoder_outputs.hidden_states,
  983. attentions=encoder_outputs.attentions,
  984. )
  985. @add_start_docstrings(
  986. """Data2VecAudio Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
  987. DATA2VEC_AUDIO_START_DOCSTRING,
  988. )
  989. class Data2VecAudioForCTC(Data2VecAudioPreTrainedModel):
  990. def __init__(self, config):
  991. super().__init__(config)
  992. self.data2vec_audio = Data2VecAudioModel(config)
  993. self.dropout = nn.Dropout(config.final_dropout)
  994. if config.vocab_size is None:
  995. raise ValueError(
  996. f"You are trying to instantiate {self.__class__} with a configuration that "
  997. "does not define the vocabulary size of the language model head. Please "
  998. "instantiate the model as follows: `Data2VecAudioForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
  999. "or define `vocab_size` of your model's configuration."
  1000. )
  1001. output_hidden_size = (
  1002. config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
  1003. )
  1004. self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
  1005. # Initialize weights and apply final processing
  1006. self.post_init()
  1007. def freeze_feature_extractor(self):
  1008. """
  1009. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1010. not be updated during training.
  1011. """
  1012. warnings.warn(
  1013. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1014. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1015. FutureWarning,
  1016. )
  1017. self.freeze_feature_encoder()
  1018. def freeze_feature_encoder(self):
  1019. """
  1020. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1021. not be updated during training.
  1022. """
  1023. self.data2vec_audio.feature_extractor._freeze_parameters()
  1024. @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
  1025. @add_code_sample_docstrings(
  1026. checkpoint=_CHECKPOINT_FOR_DOC,
  1027. output_type=CausalLMOutput,
  1028. config_class=_CONFIG_FOR_DOC,
  1029. expected_output=_CTC_EXPECTED_OUTPUT,
  1030. expected_loss=_CTC_EXPECTED_LOSS,
  1031. )
  1032. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with wav2vec2->data2vec_audio
  1033. def forward(
  1034. self,
  1035. input_values: Optional[torch.Tensor],
  1036. attention_mask: Optional[torch.Tensor] = None,
  1037. output_attentions: Optional[bool] = None,
  1038. output_hidden_states: Optional[bool] = None,
  1039. return_dict: Optional[bool] = None,
  1040. labels: Optional[torch.Tensor] = None,
  1041. ) -> Union[Tuple, CausalLMOutput]:
  1042. r"""
  1043. labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
  1044. Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
  1045. the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
  1046. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  1047. config.vocab_size - 1]`.
  1048. """
  1049. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1050. if labels is not None and labels.max() >= self.config.vocab_size:
  1051. raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
  1052. outputs = self.data2vec_audio(
  1053. input_values,
  1054. attention_mask=attention_mask,
  1055. output_attentions=output_attentions,
  1056. output_hidden_states=output_hidden_states,
  1057. return_dict=return_dict,
  1058. )
  1059. hidden_states = outputs[0]
  1060. hidden_states = self.dropout(hidden_states)
  1061. logits = self.lm_head(hidden_states)
  1062. loss = None
  1063. if labels is not None:
  1064. # retrieve loss input_lengths from attention_mask
  1065. attention_mask = (
  1066. attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
  1067. )
  1068. input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  1069. # assuming that padded tokens are filled with -100
  1070. # when not being attended to
  1071. labels_mask = labels >= 0
  1072. target_lengths = labels_mask.sum(-1)
  1073. flattened_targets = labels.masked_select(labels_mask)
  1074. # ctc_loss doesn't support fp16
  1075. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  1076. with torch.backends.cudnn.flags(enabled=False):
  1077. loss = nn.functional.ctc_loss(
  1078. log_probs,
  1079. flattened_targets,
  1080. input_lengths,
  1081. target_lengths,
  1082. blank=self.config.pad_token_id,
  1083. reduction=self.config.ctc_loss_reduction,
  1084. zero_infinity=self.config.ctc_zero_infinity,
  1085. )
  1086. if not return_dict:
  1087. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1088. return ((loss,) + output) if loss is not None else output
  1089. return CausalLMOutput(
  1090. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  1091. )
  1092. @add_start_docstrings(
  1093. """
  1094. Data2VecAudio Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
  1095. like SUPERB Keyword Spotting.
  1096. """,
  1097. DATA2VEC_AUDIO_START_DOCSTRING,
  1098. )
  1099. class Data2VecAudioForSequenceClassification(Data2VecAudioPreTrainedModel):
  1100. def __init__(self, config):
  1101. super().__init__(config)
  1102. if hasattr(config, "add_adapter") and config.add_adapter:
  1103. raise ValueError(
  1104. "Sequence classification does not support the use of Data2VecAudio adapters (config.add_adapter=True)"
  1105. )
  1106. self.data2vec_audio = Data2VecAudioModel(config)
  1107. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1108. if config.use_weighted_layer_sum:
  1109. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1110. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  1111. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  1112. # Initialize weights and apply final processing
  1113. self.post_init()
  1114. def freeze_feature_extractor(self):
  1115. """
  1116. Calling this function will disable the gradient computation for the feature encoder so that its parameters will
  1117. not be updated during training.
  1118. """
  1119. warnings.warn(
  1120. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1121. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1122. FutureWarning,
  1123. )
  1124. self.freeze_feature_encoder()
  1125. def freeze_feature_encoder(self):
  1126. """
  1127. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1128. not be updated during training.
  1129. """
  1130. self.data2vec_audio.feature_extractor._freeze_parameters()
  1131. def freeze_base_model(self):
  1132. """
  1133. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1134. be updated during training. Only the classification head will be updated.
  1135. """
  1136. for param in self.data2vec_audio.parameters():
  1137. param.requires_grad = False
  1138. @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
  1139. @add_code_sample_docstrings(
  1140. checkpoint=_CHECKPOINT_FOR_DOC,
  1141. output_type=SequenceClassifierOutput,
  1142. config_class=_CONFIG_FOR_DOC,
  1143. modality="audio",
  1144. )
  1145. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with wav2vec2->data2vec_audio
  1146. def forward(
  1147. self,
  1148. input_values: Optional[torch.Tensor],
  1149. attention_mask: Optional[torch.Tensor] = None,
  1150. output_attentions: Optional[bool] = None,
  1151. output_hidden_states: Optional[bool] = None,
  1152. return_dict: Optional[bool] = None,
  1153. labels: Optional[torch.Tensor] = None,
  1154. ) -> Union[Tuple, SequenceClassifierOutput]:
  1155. r"""
  1156. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1157. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1158. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1159. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1160. """
  1161. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1162. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1163. outputs = self.data2vec_audio(
  1164. input_values,
  1165. attention_mask=attention_mask,
  1166. output_attentions=output_attentions,
  1167. output_hidden_states=output_hidden_states,
  1168. return_dict=return_dict,
  1169. )
  1170. if self.config.use_weighted_layer_sum:
  1171. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1172. hidden_states = torch.stack(hidden_states, dim=1)
  1173. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1174. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1175. else:
  1176. hidden_states = outputs[0]
  1177. hidden_states = self.projector(hidden_states)
  1178. if attention_mask is None:
  1179. pooled_output = hidden_states.mean(dim=1)
  1180. else:
  1181. padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  1182. hidden_states[~padding_mask] = 0.0
  1183. pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
  1184. logits = self.classifier(pooled_output)
  1185. loss = None
  1186. if labels is not None:
  1187. loss_fct = CrossEntropyLoss()
  1188. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1189. if not return_dict:
  1190. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1191. return ((loss,) + output) if loss is not None else output
  1192. return SequenceClassifierOutput(
  1193. loss=loss,
  1194. logits=logits,
  1195. hidden_states=outputs.hidden_states,
  1196. attentions=outputs.attentions,
  1197. )
  1198. @add_start_docstrings(
  1199. """
  1200. Data2VecAudio Model with a frame classification head on top for tasks like Speaker Diarization.
  1201. """,
  1202. DATA2VEC_AUDIO_START_DOCSTRING,
  1203. )
  1204. class Data2VecAudioForAudioFrameClassification(Data2VecAudioPreTrainedModel):
  1205. def __init__(self, config):
  1206. super().__init__(config)
  1207. if hasattr(config, "add_adapter") and config.add_adapter:
  1208. raise ValueError(
  1209. "Audio frame classification does not support the use of Data2VecAudio adapters"
  1210. " (config.add_adapter=True)"
  1211. )
  1212. self.data2vec_audio = Data2VecAudioModel(config)
  1213. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1214. if config.use_weighted_layer_sum:
  1215. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1216. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1217. self.num_labels = config.num_labels
  1218. self.init_weights()
  1219. def freeze_feature_extractor(self):
  1220. """
  1221. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1222. not be updated during training.
  1223. """
  1224. warnings.warn(
  1225. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1226. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1227. FutureWarning,
  1228. )
  1229. self.freeze_feature_encoder()
  1230. def freeze_feature_encoder(self):
  1231. """
  1232. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1233. not be updated during training.
  1234. """
  1235. self.data2vec_audio.feature_extractor._freeze_parameters()
  1236. def freeze_base_model(self):
  1237. """
  1238. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1239. be updated during training. Only the classification head will be updated.
  1240. """
  1241. for param in self.data2vec_audio.parameters():
  1242. param.requires_grad = False
  1243. @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
  1244. @add_code_sample_docstrings(
  1245. checkpoint=_CHECKPOINT_FOR_DOC,
  1246. output_type=TokenClassifierOutput,
  1247. config_class=_CONFIG_FOR_DOC,
  1248. modality="audio",
  1249. )
  1250. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->data2vec_audio
  1251. def forward(
  1252. self,
  1253. input_values: Optional[torch.Tensor],
  1254. attention_mask: Optional[torch.Tensor] = None,
  1255. labels: Optional[torch.Tensor] = None,
  1256. output_attentions: Optional[bool] = None,
  1257. output_hidden_states: Optional[bool] = None,
  1258. return_dict: Optional[bool] = None,
  1259. ) -> Union[Tuple, TokenClassifierOutput]:
  1260. r"""
  1261. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1262. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1263. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1264. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1265. """
  1266. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1267. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1268. outputs = self.data2vec_audio(
  1269. input_values,
  1270. attention_mask=attention_mask,
  1271. output_attentions=output_attentions,
  1272. output_hidden_states=output_hidden_states,
  1273. return_dict=return_dict,
  1274. )
  1275. if self.config.use_weighted_layer_sum:
  1276. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1277. hidden_states = torch.stack(hidden_states, dim=1)
  1278. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1279. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1280. else:
  1281. hidden_states = outputs[0]
  1282. logits = self.classifier(hidden_states)
  1283. loss = None
  1284. if labels is not None:
  1285. loss_fct = CrossEntropyLoss()
  1286. loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
  1287. if not return_dict:
  1288. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1289. return output
  1290. return TokenClassifierOutput(
  1291. loss=loss,
  1292. logits=logits,
  1293. hidden_states=outputs.hidden_states,
  1294. attentions=outputs.attentions,
  1295. )
  1296. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss
  1297. class AMSoftmaxLoss(nn.Module):
  1298. def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
  1299. super(AMSoftmaxLoss, self).__init__()
  1300. self.scale = scale
  1301. self.margin = margin
  1302. self.num_labels = num_labels
  1303. self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
  1304. self.loss = nn.CrossEntropyLoss()
  1305. def forward(self, hidden_states, labels):
  1306. labels = labels.flatten()
  1307. weight = nn.functional.normalize(self.weight, dim=0)
  1308. hidden_states = nn.functional.normalize(hidden_states, dim=1)
  1309. cos_theta = torch.mm(hidden_states, weight)
  1310. psi = cos_theta - self.margin
  1311. onehot = nn.functional.one_hot(labels, self.num_labels)
  1312. logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
  1313. loss = self.loss(logits, labels)
  1314. return loss
  1315. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer
  1316. class TDNNLayer(nn.Module):
  1317. def __init__(self, config, layer_id=0):
  1318. super().__init__()
  1319. self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
  1320. self.out_conv_dim = config.tdnn_dim[layer_id]
  1321. self.kernel_size = config.tdnn_kernel[layer_id]
  1322. self.dilation = config.tdnn_dilation[layer_id]
  1323. self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
  1324. self.activation = nn.ReLU()
  1325. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1326. if is_peft_available():
  1327. from peft.tuners.lora import LoraLayer
  1328. if isinstance(self.kernel, LoraLayer):
  1329. warnings.warn(
  1330. "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
  1331. "You should exclude TDNNLayer from LoRA's target modules.",
  1332. )
  1333. # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
  1334. hidden_states = hidden_states.transpose(1, 2)
  1335. weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
  1336. hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
  1337. hidden_states = hidden_states.transpose(1, 2)
  1338. hidden_states = self.activation(hidden_states)
  1339. return hidden_states
  1340. @add_start_docstrings(
  1341. """
  1342. Data2VecAudio Model with an XVector feature extraction head on top for tasks like Speaker Verification.
  1343. """,
  1344. DATA2VEC_AUDIO_START_DOCSTRING,
  1345. )
  1346. class Data2VecAudioForXVector(Data2VecAudioPreTrainedModel):
  1347. def __init__(self, config):
  1348. super().__init__(config)
  1349. self.data2vec_audio = Data2VecAudioModel(config)
  1350. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1351. if config.use_weighted_layer_sum:
  1352. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1353. self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
  1354. tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
  1355. self.tdnn = nn.ModuleList(tdnn_layers)
  1356. self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
  1357. self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
  1358. self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
  1359. self.init_weights()
  1360. def freeze_feature_extractor(self):
  1361. """
  1362. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1363. not be updated during training.
  1364. """
  1365. warnings.warn(
  1366. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1367. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1368. FutureWarning,
  1369. )
  1370. self.freeze_feature_encoder()
  1371. def freeze_feature_encoder(self):
  1372. """
  1373. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1374. not be updated during training.
  1375. """
  1376. self.data2vec_audio.feature_extractor._freeze_parameters()
  1377. def freeze_base_model(self):
  1378. """
  1379. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1380. be updated during training. Only the classification head will be updated.
  1381. """
  1382. for param in self.data2vec_audio.parameters():
  1383. param.requires_grad = False
  1384. def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
  1385. """
  1386. Computes the output length of the TDNN layers
  1387. """
  1388. def _conv_out_length(input_length, kernel_size, stride):
  1389. # 1D convolutional layer output length formula taken
  1390. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  1391. return (input_length - kernel_size) // stride + 1
  1392. for kernel_size in self.config.tdnn_kernel:
  1393. input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
  1394. return input_lengths
  1395. @add_start_docstrings_to_model_forward(DATA2VEC_AUDIO_INPUTS_DOCSTRING)
  1396. @add_code_sample_docstrings(
  1397. checkpoint=_CHECKPOINT_FOR_DOC,
  1398. output_type=XVectorOutput,
  1399. config_class=_CONFIG_FOR_DOC,
  1400. modality="audio",
  1401. )
  1402. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with wav2vec2->data2vec_audio
  1403. def forward(
  1404. self,
  1405. input_values: Optional[torch.Tensor],
  1406. attention_mask: Optional[torch.Tensor] = None,
  1407. output_attentions: Optional[bool] = None,
  1408. output_hidden_states: Optional[bool] = None,
  1409. return_dict: Optional[bool] = None,
  1410. labels: Optional[torch.Tensor] = None,
  1411. ) -> Union[Tuple, XVectorOutput]:
  1412. r"""
  1413. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1414. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1415. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1416. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1417. """
  1418. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1419. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1420. outputs = self.data2vec_audio(
  1421. input_values,
  1422. attention_mask=attention_mask,
  1423. output_attentions=output_attentions,
  1424. output_hidden_states=output_hidden_states,
  1425. return_dict=return_dict,
  1426. )
  1427. if self.config.use_weighted_layer_sum:
  1428. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1429. hidden_states = torch.stack(hidden_states, dim=1)
  1430. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1431. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1432. else:
  1433. hidden_states = outputs[0]
  1434. hidden_states = self.projector(hidden_states)
  1435. for tdnn_layer in self.tdnn:
  1436. hidden_states = tdnn_layer(hidden_states)
  1437. # Statistic Pooling
  1438. if attention_mask is None:
  1439. mean_features = hidden_states.mean(dim=1)
  1440. std_features = hidden_states.std(dim=1)
  1441. else:
  1442. feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
  1443. tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
  1444. mean_features = []
  1445. std_features = []
  1446. for i, length in enumerate(tdnn_output_lengths):
  1447. mean_features.append(hidden_states[i, :length].mean(dim=0))
  1448. std_features.append(hidden_states[i, :length].std(dim=0))
  1449. mean_features = torch.stack(mean_features)
  1450. std_features = torch.stack(std_features)
  1451. statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
  1452. output_embeddings = self.feature_extractor(statistic_pooling)
  1453. logits = self.classifier(output_embeddings)
  1454. loss = None
  1455. if labels is not None:
  1456. loss = self.objective(logits, labels)
  1457. if not return_dict:
  1458. output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
  1459. return ((loss,) + output) if loss is not None else output
  1460. return XVectorOutput(
  1461. loss=loss,
  1462. logits=logits,
  1463. embeddings=output_embeddings,
  1464. hidden_states=outputs.hidden_states,
  1465. attentions=outputs.attentions,
  1466. )