# modeling_hubert.py
  1. # coding=utf-8
  2. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Hubert model."""
  16. import warnings
  17. from typing import Optional, Tuple, Union
  18. import numpy as np
  19. import torch
  20. import torch.utils.checkpoint
  21. from torch import nn
  22. from torch.nn import CrossEntropyLoss
  23. from ...activations import ACT2FN
  24. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  25. from ...integrations.fsdp import is_fsdp_managed_module
  26. from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
  27. from ...modeling_utils import PreTrainedModel
  28. from ...utils import (
  29. add_code_sample_docstrings,
  30. add_start_docstrings,
  31. add_start_docstrings_to_model_forward,
  32. is_flash_attn_2_available,
  33. is_flash_attn_greater_or_equal_2_10,
  34. logging,
  35. replace_return_docstrings,
  36. )
  37. from .configuration_hubert import HubertConfig
  38. if is_flash_attn_2_available():
  39. from ...modeling_flash_attention_utils import _flash_attention_forward
  40. logger = logging.get_logger(__name__)
  41. _HIDDEN_STATES_START_POSITION = 1
  42. # General docstring
  43. _CONFIG_FOR_DOC = "HubertConfig"
  44. # Base docstring
  45. _CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
  46. _EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
  47. # CTC docstring
  48. _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
  49. _CTC_EXPECTED_LOSS = 22.68
  50. # Audio class docstring
  51. _SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
  52. _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
  53. _SEQ_CLASS_EXPECTED_LOSS = 8.53
  54. # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
  55. def _compute_mask_indices(
  56. shape: Tuple[int, int],
  57. mask_prob: float,
  58. mask_length: int,
  59. attention_mask: Optional[torch.LongTensor] = None,
  60. min_masks: int = 0,
  61. ) -> np.ndarray:
  62. """
  63. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  64. ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  65. CPU as part of the preprocessing during training.
  66. Args:
  67. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  68. the first element is the batch size and the second element is the length of the axis to span.
  69. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  70. independently generated mask spans of length `mask_length` is computed by
  71. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  72. actual percentage will be smaller.
  73. mask_length: size of the mask
  74. min_masks: minimum number of masked spans
  75. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  76. each batch dimension.
  77. """
  78. batch_size, sequence_length = shape
  79. if mask_length < 1:
  80. raise ValueError("`mask_length` has to be bigger than 0.")
  81. if mask_length > sequence_length:
  82. raise ValueError(
  83. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  84. f" and `sequence_length`: {sequence_length}`"
  85. )
  86. # epsilon is used for probabilistic rounding
  87. epsilon = np.random.rand(1).item()
  88. def compute_num_masked_span(input_length):
  89. """Given input length, compute how many spans should be masked"""
  90. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  91. num_masked_span = max(num_masked_span, min_masks)
  92. # make sure num masked span <= sequence_length
  93. if num_masked_span * mask_length > sequence_length:
  94. num_masked_span = sequence_length // mask_length
  95. # make sure num_masked span is also <= input_length - (mask_length - 1)
  96. if input_length - (mask_length - 1) < num_masked_span:
  97. num_masked_span = max(input_length - (mask_length - 1), 0)
  98. return num_masked_span
  99. # compute number of masked spans in batch
  100. input_lengths = (
  101. attention_mask.sum(-1).detach().tolist()
  102. if attention_mask is not None
  103. else [sequence_length for _ in range(batch_size)]
  104. )
  105. # SpecAugment mask to fill
  106. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  107. spec_aug_mask_idxs = []
  108. max_num_masked_span = compute_num_masked_span(sequence_length)
  109. if max_num_masked_span == 0:
  110. return spec_aug_mask
  111. for input_length in input_lengths:
  112. # compute num of masked spans for this input
  113. num_masked_span = compute_num_masked_span(input_length)
  114. # get random indices to mask
  115. spec_aug_mask_idx = np.random.choice(
  116. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  117. )
  118. # pick first sampled index that will serve as a dummy index to pad vector
  119. # to ensure same dimension for all batches due to probabilistic rounding
  120. # Picking first sample just pads those vectors twice.
  121. if len(spec_aug_mask_idx) == 0:
  122. # this case can only happen if `input_length` is strictly smaller then
  123. # `sequence_length` in which case the last token has to be a padding
  124. # token which we can use as a dummy mask id
  125. dummy_mask_idx = sequence_length - 1
  126. else:
  127. dummy_mask_idx = spec_aug_mask_idx[0]
  128. spec_aug_mask_idx = np.concatenate(
  129. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  130. )
  131. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  132. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  133. # expand masked indices to masked spans
  134. spec_aug_mask_idxs = np.broadcast_to(
  135. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  136. )
  137. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  138. # add offset to the starting indexes so that indexes now create a span
  139. offsets = np.arange(mask_length)[None, None, :]
  140. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  141. batch_size, max_num_masked_span * mask_length
  142. )
  143. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  144. # ensure that we cannot have indices larger than sequence_length
  145. if spec_aug_mask_idxs.max() > sequence_length - 1:
  146. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  147. # scatter indices to mask
  148. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  149. return spec_aug_mask
  150. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Hubert
  151. class HubertNoLayerNormConvLayer(nn.Module):
  152. def __init__(self, config, layer_id=0):
  153. super().__init__()
  154. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  155. self.out_conv_dim = config.conv_dim[layer_id]
  156. self.conv = nn.Conv1d(
  157. self.in_conv_dim,
  158. self.out_conv_dim,
  159. kernel_size=config.conv_kernel[layer_id],
  160. stride=config.conv_stride[layer_id],
  161. bias=config.conv_bias,
  162. )
  163. self.activation = ACT2FN[config.feat_extract_activation]
  164. def forward(self, hidden_states):
  165. hidden_states = self.conv(hidden_states)
  166. hidden_states = self.activation(hidden_states)
  167. return hidden_states
  168. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Hubert
  169. class HubertLayerNormConvLayer(nn.Module):
  170. def __init__(self, config, layer_id=0):
  171. super().__init__()
  172. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  173. self.out_conv_dim = config.conv_dim[layer_id]
  174. self.conv = nn.Conv1d(
  175. self.in_conv_dim,
  176. self.out_conv_dim,
  177. kernel_size=config.conv_kernel[layer_id],
  178. stride=config.conv_stride[layer_id],
  179. bias=config.conv_bias,
  180. )
  181. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  182. self.activation = ACT2FN[config.feat_extract_activation]
  183. def forward(self, hidden_states):
  184. hidden_states = self.conv(hidden_states)
  185. hidden_states = hidden_states.transpose(-2, -1)
  186. hidden_states = self.layer_norm(hidden_states)
  187. hidden_states = hidden_states.transpose(-2, -1)
  188. hidden_states = self.activation(hidden_states)
  189. return hidden_states
  190. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Hubert
  191. class HubertGroupNormConvLayer(nn.Module):
  192. def __init__(self, config, layer_id=0):
  193. super().__init__()
  194. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  195. self.out_conv_dim = config.conv_dim[layer_id]
  196. self.conv = nn.Conv1d(
  197. self.in_conv_dim,
  198. self.out_conv_dim,
  199. kernel_size=config.conv_kernel[layer_id],
  200. stride=config.conv_stride[layer_id],
  201. bias=config.conv_bias,
  202. )
  203. self.activation = ACT2FN[config.feat_extract_activation]
  204. self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
  205. def forward(self, hidden_states):
  206. hidden_states = self.conv(hidden_states)
  207. hidden_states = self.layer_norm(hidden_states)
  208. hidden_states = self.activation(hidden_states)
  209. return hidden_states
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert
class HubertPositionalConvEmbedding(nn.Module):
    """Convolutional relative positional embedding applied along the time axis."""

    def __init__(self, config):
        super().__init__()
        # Grouped conv over time; padding of kernel_size // 2 keeps the output length equal
        # to the input length (the one extra frame produced by even kernels is trimmed by
        # HubertSamePadLayer below).
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=config.num_conv_pos_embeddings,
            padding=config.num_conv_pos_embeddings // 2,
            groups=config.num_conv_pos_embedding_groups,
        )

        # Prefer the parametrization-based weight norm on newer torch versions;
        # fall back to the legacy nn.utils.weight_norm otherwise.
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        if is_deepspeed_zero3_enabled():
            import deepspeed

            # Under ZeRO-3 the conv weight is partitioned across ranks: gather it
            # (modifier_rank=0) so weight norm can be applied, then register the
            # resulting g/v tensors as external parameters so deepspeed tracks them.
            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                self.conv = weight_norm(self.conv, name="weight", dim=2)
            # attribute layout differs between the two weight-norm implementations
            if hasattr(self.conv, "parametrizations"):
                weight_g = self.conv.parametrizations.weight.original0
                weight_v = self.conv.parametrizations.weight.original1
            else:
                weight_g = self.conv.weight_g
                weight_v = self.conv.weight_v
            deepspeed.zero.register_external_parameter(self, weight_v)
            deepspeed.zero.register_external_parameter(self, weight_g)
        else:
            self.conv = weight_norm(self.conv, name="weight", dim=2)

        self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        # conv expects (batch, channels, time); inputs arrive as (batch, time, hidden)
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = self.conv(hidden_states)
        hidden_states = self.padding(hidden_states)
        hidden_states = self.activation(hidden_states)

        # back to (batch, time, hidden)
        hidden_states = hidden_states.transpose(1, 2)
        return hidden_states
  247. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->Hubert
  248. class HubertSamePadLayer(nn.Module):
  249. def __init__(self, num_conv_pos_embeddings):
  250. super().__init__()
  251. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  252. def forward(self, hidden_states):
  253. if self.num_pad_remove > 0:
  254. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  255. return hidden_states
  256. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Hubert
  257. class HubertFeatureEncoder(nn.Module):
  258. """Construct the features from raw audio waveform"""
  259. def __init__(self, config):
  260. super().__init__()
  261. if config.feat_extract_norm == "group":
  262. conv_layers = [HubertGroupNormConvLayer(config, layer_id=0)] + [
  263. HubertNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
  264. ]
  265. elif config.feat_extract_norm == "layer":
  266. conv_layers = [HubertLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
  267. else:
  268. raise ValueError(
  269. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  270. )
  271. self.conv_layers = nn.ModuleList(conv_layers)
  272. self.gradient_checkpointing = False
  273. self._requires_grad = True
  274. def _freeze_parameters(self):
  275. for param in self.parameters():
  276. param.requires_grad = False
  277. self._requires_grad = False
  278. def forward(self, input_values):
  279. hidden_states = input_values[:, None]
  280. # make sure hidden_states require grad for gradient_checkpointing
  281. if self._requires_grad and self.training:
  282. hidden_states.requires_grad = True
  283. for conv_layer in self.conv_layers:
  284. if self._requires_grad and self.gradient_checkpointing and self.training:
  285. hidden_states = self._gradient_checkpointing_func(
  286. conv_layer.__call__,
  287. hidden_states,
  288. )
  289. else:
  290. hidden_states = conv_layer(hidden_states)
  291. return hidden_states
  292. class HubertFeatureExtractor(HubertFeatureEncoder):
  293. def __init__(self, config):
  294. super().__init__(config)
  295. warnings.warn(
  296. f"The class `{self.__class__.__name__}` has been depreciated "
  297. "and will be removed in Transformers v5. "
  298. f"Use `{self.__class__.__bases__[0].__name__}` instead.",
  299. FutureWarning,
  300. )
  301. class HubertFeatureProjection(nn.Module):
  302. def __init__(self, config):
  303. super().__init__()
  304. self.feat_proj_layer_norm = config.feat_proj_layer_norm
  305. if self.feat_proj_layer_norm:
  306. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  307. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  308. self.dropout = nn.Dropout(config.feat_proj_dropout)
  309. def forward(self, hidden_states):
  310. # non-projected hidden states are needed for quantization
  311. if self.feat_proj_layer_norm:
  312. hidden_states = self.layer_norm(hidden_states)
  313. hidden_states = self.projection(hidden_states)
  314. hidden_states = self.dropout(hidden_states)
  315. return hidden_states
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Hubert
class HubertAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        config: Optional[HubertConfig] = None,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.config = config

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        # 1/sqrt(head_dim); pre-applied to the query projection in forward()
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns `(attn_output, attn_weights_or_None, past_key_value_or_None)`;
        `attn_weights` is only materialized when `output_attentions=True`, and
        `past_key_value` is only populated when `self.is_decoder` is True.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # fold the head dimension into the batch dimension so bmm runs over all heads at once
        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            # additive mask, applied before softmax
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            # per-head multiplicative mask applied after softmax
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
  450. # Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Hubert
  451. class HubertFlashAttention2(HubertAttention):
  452. """
  453. Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays
  454. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  455. flash attention and deal with padding tokens in case the input contains any of them.
  456. """
  457. # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
  458. def __init__(self, *args, **kwargs):
  459. super().__init__(*args, **kwargs)
  460. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  461. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  462. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  463. self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
  464. def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  465. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
  466. def forward(
  467. self,
  468. hidden_states: torch.Tensor,
  469. key_value_states: Optional[torch.Tensor] = None,
  470. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  471. attention_mask: Optional[torch.Tensor] = None,
  472. layer_head_mask: Optional[torch.Tensor] = None,
  473. output_attentions: bool = False,
  474. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  475. # HubertFlashAttention2 attention does not support output_attentions
  476. if output_attentions:
  477. raise ValueError("HubertFlashAttention2 attention does not support output_attentions")
  478. # if key_value_states are provided this layer is used as a cross-attention layer
  479. # for the decoder
  480. is_cross_attention = key_value_states is not None
  481. bsz, q_len, _ = hidden_states.size()
  482. # get query proj
  483. query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
  484. # get key, value proj
  485. # `past_key_value[0].shape[2] == key_value_states.shape[1]`
  486. # is checking that the `sequence_length` of the `past_key_value` is the same as
  487. # the provided `key_value_states` to support prefix tuning
  488. if (
  489. is_cross_attention
  490. and past_key_value is not None
  491. and past_key_value[0].shape[2] == key_value_states.shape[1]
  492. ):
  493. # reuse k,v, cross_attentions
  494. key_states = past_key_value[0].transpose(1, 2)
  495. value_states = past_key_value[1].transpose(1, 2)
  496. elif is_cross_attention:
  497. # cross_attentions
  498. key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
  499. value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
  500. elif past_key_value is not None:
  501. # reuse k, v, self_attention
  502. key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
  503. value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
  504. key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
  505. value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
  506. else:
  507. # self_attention
  508. key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
  509. value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
  510. if self.is_decoder:
  511. # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
  512. # Further calls to cross_attention layer can then reuse all cross-attention
  513. # key/value_states (first "if" case)
  514. # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
  515. # all previous decoder key/value_states. Further calls to uni-directional self-attention
  516. # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
  517. # if encoder bi-directional self-attention `past_key_value` is always `None`
  518. past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
  519. kv_seq_len = key_states.shape[-2]
  520. if past_key_value is not None:
  521. kv_seq_len += past_key_value[0].shape[-2]
  522. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  523. # therefore the input hidden states gets silently casted in float32. Hence, we need
  524. # cast them back in the correct dtype just to be sure everything works as expected.
  525. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  526. # in fp32. (LlamaRMSNorm handles it correctly)
  527. input_dtype = query_states.dtype
  528. if input_dtype == torch.float32:
  529. if torch.is_autocast_enabled():
  530. target_dtype = torch.get_autocast_gpu_dtype()
  531. # Handle the case where the model is quantized
  532. elif hasattr(self.config, "_pre_quantization_dtype"):
  533. target_dtype = self.config._pre_quantization_dtype
  534. else:
  535. target_dtype = self.q_proj.weight.dtype
  536. logger.warning_once(
  537. f"The input hidden states seems to be silently casted in float32, this might be related to"
  538. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  539. f" {target_dtype}."
  540. )
  541. query_states = query_states.to(target_dtype)
  542. key_states = key_states.to(target_dtype)
  543. value_states = value_states.to(target_dtype)
  544. attn_output = _flash_attention_forward(
  545. query_states,
  546. key_states,
  547. value_states,
  548. attention_mask,
  549. q_len,
  550. dropout=self.dropout if self.training else 0.0,
  551. is_causal=self.is_causal,
  552. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  553. )
  554. attn_output = attn_output.reshape(bsz, q_len, -1)
  555. attn_output = self.out_proj(attn_output)
  556. if not output_attentions:
  557. attn_weights = None
  558. return attn_output, attn_weights, past_key_value
class HubertSdpaAttention(HubertAttention):
    """`HubertAttention` variant that dispatches to `torch.nn.functional.scaled_dot_product_attention` (SDPA).

    Falls back to the eager (parent-class) implementation whenever attention weights or a
    per-head mask are requested, since SDPA supports neither.
    """

    # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Hubert
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        if output_attentions or layer_head_mask is not None:
            # SDPA never materializes the attention-probability matrix and cannot apply a
            # per-head mask, so delegate to the eager implementation in those cases.
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "HubertModel is using HubertSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                key_value_states=key_value_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                layer_head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        # `past_key_value[0].shape[2] == key_value_states.shape[1]`
        # is checking that the `sequence_length` of the `past_key_value` is the same as
        # the provided `key_value_states` to support prefix tuning
        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            # cached entries are (bsz, num_heads, past_len, head_dim); concat along the time axis
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        query_states = self._shape(query_states, tgt_len, bsz)

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
        is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False

        # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
        # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        # second slot is always None: SDPA cannot return attention weights (see fallback above)
        return attn_output, None, past_key_value
# Maps `config._attn_implementation` to the attention class instantiated by the encoder layers.
HUBERT_ATTENTION_CLASSES = {
    "eager": HubertAttention,
    "sdpa": HubertSdpaAttention,
    "flash_attention_2": HubertFlashAttention2,
}
  657. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Hubert
  658. class HubertFeedForward(nn.Module):
  659. def __init__(self, config):
  660. super().__init__()
  661. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  662. self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
  663. if isinstance(config.hidden_act, str):
  664. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  665. else:
  666. self.intermediate_act_fn = config.hidden_act
  667. self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
  668. self.output_dropout = nn.Dropout(config.hidden_dropout)
  669. def forward(self, hidden_states):
  670. hidden_states = self.intermediate_dense(hidden_states)
  671. hidden_states = self.intermediate_act_fn(hidden_states)
  672. hidden_states = self.intermediate_dropout(hidden_states)
  673. hidden_states = self.output_dense(hidden_states)
  674. hidden_states = self.output_dropout(hidden_states)
  675. return hidden_states
  676. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Hubert, WAV2VEC2->HUBERT
  677. class HubertEncoderLayer(nn.Module):
  678. def __init__(self, config):
  679. super().__init__()
  680. self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation](
  681. embed_dim=config.hidden_size,
  682. num_heads=config.num_attention_heads,
  683. dropout=config.attention_dropout,
  684. is_decoder=False,
  685. )
  686. self.dropout = nn.Dropout(config.hidden_dropout)
  687. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  688. self.feed_forward = HubertFeedForward(config)
  689. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  690. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  691. attn_residual = hidden_states
  692. hidden_states, attn_weights, _ = self.attention(
  693. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  694. )
  695. hidden_states = self.dropout(hidden_states)
  696. hidden_states = attn_residual + hidden_states
  697. hidden_states = self.layer_norm(hidden_states)
  698. hidden_states = hidden_states + self.feed_forward(hidden_states)
  699. hidden_states = self.final_layer_norm(hidden_states)
  700. outputs = (hidden_states,)
  701. if output_attentions:
  702. outputs += (attn_weights,)
  703. return outputs
  704. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AttnAdapterLayer with Wav2Vec2->Hubert
  705. class HubertAttnAdapterLayer(nn.Module):
  706. def __init__(self, config):
  707. """
  708. Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
  709. up training throughput.
  710. """
  711. super().__init__()
  712. self.input_dim = config.adapter_attn_dim
  713. self.hidden_dim = config.hidden_size
  714. self.norm = nn.LayerNorm(self.hidden_dim)
  715. self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
  716. self.act_fn = nn.ReLU()
  717. self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
  718. def forward(self, hidden_states: torch.FloatTensor):
  719. hidden_states = self.norm(hidden_states)
  720. hidden_states = self.linear_1(hidden_states)
  721. hidden_states = self.act_fn(hidden_states)
  722. hidden_states = self.linear_2(hidden_states)
  723. return hidden_states
  724. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert, WAV2VEC2->HUBERT
  725. class HubertEncoderLayerStableLayerNorm(nn.Module):
  726. def __init__(self, config):
  727. super().__init__()
  728. self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation](
  729. embed_dim=config.hidden_size,
  730. num_heads=config.num_attention_heads,
  731. dropout=config.attention_dropout,
  732. is_decoder=False,
  733. )
  734. self.dropout = nn.Dropout(config.hidden_dropout)
  735. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  736. self.feed_forward = HubertFeedForward(config)
  737. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  738. if getattr(config, "adapter_attn_dim", None) is not None:
  739. self.adapter_layer = HubertAttnAdapterLayer(config)
  740. else:
  741. self.adapter_layer = None
  742. def forward(
  743. self,
  744. hidden_states: torch.Tensor,
  745. attention_mask: Optional[torch.Tensor] = None,
  746. output_attentions: bool = False,
  747. ):
  748. attn_residual = hidden_states
  749. hidden_states = self.layer_norm(hidden_states)
  750. hidden_states, attn_weights, _ = self.attention(
  751. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  752. )
  753. hidden_states = self.dropout(hidden_states)
  754. hidden_states = attn_residual + hidden_states
  755. hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
  756. if self.adapter_layer is not None:
  757. hidden_states = hidden_states + self.adapter_layer(hidden_states)
  758. outputs = (hidden_states,)
  759. if output_attentions:
  760. outputs += (attn_weights,)
  761. return outputs
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Encoder with Wav2Vec2->Hubert
class HubertEncoder(nn.Module):
    """Post-norm Transformer encoder stack: convolutional positional embedding, LayerNorm,
    dropout, then `config.num_hidden_layers` `HubertEncoderLayer`s with LayerDrop."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = HubertPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens output 0
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0
            if self._use_flash_attention_2:
                # 2d mask is passed through the layers
                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
            else:
                # extend attention_mask
                # turn the 2D padding mask into an additive 4D mask: 0 where attended, dtype-min where padded
                attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
                attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
                attention_mask = attention_mask.expand(
                    attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
                )

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        output_attentions,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                # the layer ran only for GPU synchronization; discard its outputs
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderStableLayerNorm with Wav2Vec2->Hubert
class HubertEncoderStableLayerNorm(nn.Module):
    """Pre-norm Transformer encoder stack: the LayerNorm is applied once after the whole
    layer stack instead of before it (the "stable layer norm" variant)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = HubertPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(
            [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens are not attended to
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype)
            if self._use_flash_attention_2:
                # 2d mask is passed through the layers
                attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
            else:
                # extend attention_mask
                # turn the 2D padding mask into an additive 4D mask: 0 where attended, dtype-min where padded
                attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
                attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
                attention_mask = attention_mask.expand(
                    attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
                )

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states)

        synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = torch.rand([])

            skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
            if not skip_the_layer or synced_gpus:
                # under fsdp or deepspeed zero3 all gpus must run in sync
                # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        layer.__call__,
                        hidden_states,
                        attention_mask,
                        output_attentions,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                # the layer ran only for GPU synchronization; discard its outputs
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        # final normalization of the pre-norm stack
        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class HubertPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = HubertConfig
    base_model_prefix = "hubert"
    main_input_name = "input_values"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            if is_deepspeed_zero3_enabled():
                import deepspeed

                # under ZeRO-3 the parameters are partitioned across ranks; gather them
                # before initializing in place (weight-norm convs expose weight_v/weight_g)
                if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                    with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
                else:
                    with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                        nn.init.kaiming_normal_(module.weight.data)
            else:
                nn.init.kaiming_normal_(module.weight.data)

        if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
            module.bias.data.zero_()

    def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula taken
            # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
            return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

        # each conv layer shrinks the sequence; fold the formula over the whole stack
        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths

    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
        # Downsample the waveform-resolution attention mask to the conv-feature resolution.
        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
        batch_size = attention_mask.shape[0]

        attention_mask = torch.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
        )
        # these two operations makes sure that all values before the output lengths idxs are attended to
        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
        return attention_mask
# Class-level docstring injected by `add_start_docstrings` onto the Hubert model classes below.
HUBERT_START_DOCSTRING = r"""
    Hubert was proposed in [HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden
    Units](https://arxiv.org/abs/2106.07447) by Wei-Ning Hsu, Benjamin Bolte, Yao-Hung Hubert Tsai, Kushal Lakhotia,
    Ruslan Salakhutdinov, Abdelrahman Mohamed.

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving etc.).

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`HubertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


# Shared `forward` argument documentation injected by `add_start_docstrings_to_model_forward`.
HUBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
            into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install
            soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and
            conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            <Tip warning={true}>

            `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask ==
            True`. For all models whose processor has `config.return_attention_mask == False`, such as
            [hubert-base](https://huggingface.co/facebook/hubert-base-ls960), `attention_mask` should **not** be passed
            to avoid degraded performance when doing batched inference. For such models `input_values` should simply be
            padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly different
            results depending on whether `input_values` is padded or not.

            </Tip>

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
    "The bare Hubert Model transformer outputting raw hidden-states without any specific head on top.",
    HUBERT_START_DOCSTRING,
)
class HubertModel(HubertPreTrainedModel):
    def __init__(self, config: HubertConfig):
        super().__init__(config)
        self.config = config
        self.feature_extractor = HubertFeatureEncoder(config)
        self.feature_projection = HubertFeatureProjection(config)

        # learned vector substituted for masked time steps during SpecAugment;
        # only created when either masking probability is non-zero
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            self.encoder = HubertEncoderStableLayerNorm(config)
        else:
            self.encoder = HubertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            # broadcast the per-feature mask over the whole time axis before zeroing
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """

        Returns:

        Example:

        ```python
        >>> from transformers import AutoProcessor, HubertModel
        >>> from datasets import load_dataset
        >>> import soundfile as sf

        >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
        >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")


        >>> def map_to_array(batch):
        ...     speech, _ = sf.read(batch["file"])
        ...     batch["speech"] = speech
        ...     return batch


        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.map(map_to_array)

        >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
        >>> hidden_states = model(input_values).last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # raw waveform -> conv features; transpose (batch, dim, frames) -> (batch, frames, dim)
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
  1124. @add_start_docstrings(
  1125. """Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
  1126. HUBERT_START_DOCSTRING,
  1127. )
  1128. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
  1129. class HubertForCTC(HubertPreTrainedModel):
  1130. def __init__(self, config, target_lang: Optional[str] = None):
  1131. super().__init__(config)
  1132. self.hubert = HubertModel(config)
  1133. self.dropout = nn.Dropout(config.final_dropout)
  1134. self.target_lang = target_lang
  1135. if config.vocab_size is None:
  1136. raise ValueError(
  1137. f"You are trying to instantiate {self.__class__} with a configuration that "
  1138. "does not define the vocabulary size of the language model head. Please "
  1139. "instantiate the model as follows: `HubertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
  1140. "or define `vocab_size` of your model's configuration."
  1141. )
  1142. output_hidden_size = (
  1143. config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
  1144. )
  1145. self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
  1146. # Initialize weights and apply final processing
  1147. self.post_init()
  1148. def tie_weights(self):
  1149. """
  1150. This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
  1151. passing `target_lang=...` to `from_pretrained(...)`.
  1152. This method is **not** supposed to be called by the user and is prone to be changed in the future.
  1153. """
  1154. # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
  1155. # correctly load adapter layers for Hubert so that we do not have to introduce a new API to
  1156. # [`PreTrainedModel`]. While slightly hacky, Hubert never has to tie input and output embeddings, so that it is
  1157. # ok to repurpose this function here.
  1158. target_lang = self.target_lang
  1159. if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
  1160. raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
  1161. elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
  1162. logger.info("By default `target_lang` is set to 'eng'.")
  1163. elif target_lang is not None:
  1164. self.load_adapter(target_lang, force_load=True)
  1165. def freeze_feature_extractor(self):
  1166. """
  1167. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1168. not be updated during training.
  1169. """
  1170. warnings.warn(
  1171. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1172. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1173. FutureWarning,
  1174. )
  1175. self.freeze_feature_encoder()
  1176. def freeze_feature_encoder(self):
  1177. """
  1178. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1179. not be updated during training.
  1180. """
  1181. self.hubert.feature_extractor._freeze_parameters()
  1182. def freeze_base_model(self):
  1183. """
  1184. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1185. be updated during training. Only the classification head will be updated.
  1186. """
  1187. for param in self.hubert.parameters():
  1188. param.requires_grad = False
  1189. @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
  1190. @add_code_sample_docstrings(
  1191. checkpoint=_CHECKPOINT_FOR_DOC,
  1192. output_type=CausalLMOutput,
  1193. config_class=_CONFIG_FOR_DOC,
  1194. expected_output=_CTC_EXPECTED_OUTPUT,
  1195. expected_loss=_CTC_EXPECTED_LOSS,
  1196. )
  1197. def forward(
  1198. self,
  1199. input_values: Optional[torch.Tensor],
  1200. attention_mask: Optional[torch.Tensor] = None,
  1201. output_attentions: Optional[bool] = None,
  1202. output_hidden_states: Optional[bool] = None,
  1203. return_dict: Optional[bool] = None,
  1204. labels: Optional[torch.Tensor] = None,
  1205. ) -> Union[Tuple, CausalLMOutput]:
  1206. r"""
  1207. labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
  1208. Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
  1209. the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
  1210. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  1211. config.vocab_size - 1]`.
  1212. """
  1213. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1214. if labels is not None and labels.max() >= self.config.vocab_size:
  1215. raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
  1216. outputs = self.hubert(
  1217. input_values,
  1218. attention_mask=attention_mask,
  1219. output_attentions=output_attentions,
  1220. output_hidden_states=output_hidden_states,
  1221. return_dict=return_dict,
  1222. )
  1223. hidden_states = outputs[0]
  1224. hidden_states = self.dropout(hidden_states)
  1225. logits = self.lm_head(hidden_states)
  1226. loss = None
  1227. if labels is not None:
  1228. # retrieve loss input_lengths from attention_mask
  1229. attention_mask = (
  1230. attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
  1231. )
  1232. input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  1233. # assuming that padded tokens are filled with -100
  1234. # when not being attended to
  1235. labels_mask = labels >= 0
  1236. target_lengths = labels_mask.sum(-1)
  1237. flattened_targets = labels.masked_select(labels_mask)
  1238. # ctc_loss doesn't support fp16
  1239. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  1240. with torch.backends.cudnn.flags(enabled=False):
  1241. loss = nn.functional.ctc_loss(
  1242. log_probs,
  1243. flattened_targets,
  1244. input_lengths,
  1245. target_lengths,
  1246. blank=self.config.pad_token_id,
  1247. reduction=self.config.ctc_loss_reduction,
  1248. zero_infinity=self.config.ctc_zero_infinity,
  1249. )
  1250. if not return_dict:
  1251. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1252. return ((loss,) + output) if loss is not None else output
  1253. return CausalLMOutput(
  1254. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  1255. )
  1256. @add_start_docstrings(
  1257. """
  1258. Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
  1259. SUPERB Keyword Spotting.
  1260. """,
  1261. HUBERT_START_DOCSTRING,
  1262. )
  1263. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->Hubert, wav2vec2->hubert, WAV_2_VEC_2->HUBERT
  1264. class HubertForSequenceClassification(HubertPreTrainedModel):
  1265. def __init__(self, config):
  1266. super().__init__(config)
  1267. if hasattr(config, "add_adapter") and config.add_adapter:
  1268. raise ValueError(
  1269. "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
  1270. )
  1271. self.hubert = HubertModel(config)
  1272. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1273. if config.use_weighted_layer_sum:
  1274. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1275. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  1276. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  1277. # Initialize weights and apply final processing
  1278. self.post_init()
  1279. def freeze_feature_extractor(self):
  1280. """
  1281. Calling this function will disable the gradient computation for the feature encoder so that its parameters will
  1282. not be updated during training.
  1283. """
  1284. warnings.warn(
  1285. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1286. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1287. FutureWarning,
  1288. )
  1289. self.freeze_feature_encoder()
  1290. def freeze_feature_encoder(self):
  1291. """
  1292. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1293. not be updated during training.
  1294. """
  1295. self.hubert.feature_extractor._freeze_parameters()
  1296. def freeze_base_model(self):
  1297. """
  1298. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1299. be updated during training. Only the classification head will be updated.
  1300. """
  1301. for param in self.hubert.parameters():
  1302. param.requires_grad = False
  1303. @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
  1304. @add_code_sample_docstrings(
  1305. checkpoint=_SEQ_CLASS_CHECKPOINT,
  1306. output_type=SequenceClassifierOutput,
  1307. config_class=_CONFIG_FOR_DOC,
  1308. modality="audio",
  1309. expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
  1310. expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
  1311. )
  1312. def forward(
  1313. self,
  1314. input_values: Optional[torch.Tensor],
  1315. attention_mask: Optional[torch.Tensor] = None,
  1316. output_attentions: Optional[bool] = None,
  1317. output_hidden_states: Optional[bool] = None,
  1318. return_dict: Optional[bool] = None,
  1319. labels: Optional[torch.Tensor] = None,
  1320. ) -> Union[Tuple, SequenceClassifierOutput]:
  1321. r"""
  1322. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1323. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1324. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1325. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1326. """
  1327. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1328. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1329. outputs = self.hubert(
  1330. input_values,
  1331. attention_mask=attention_mask,
  1332. output_attentions=output_attentions,
  1333. output_hidden_states=output_hidden_states,
  1334. return_dict=return_dict,
  1335. )
  1336. if self.config.use_weighted_layer_sum:
  1337. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1338. hidden_states = torch.stack(hidden_states, dim=1)
  1339. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1340. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1341. else:
  1342. hidden_states = outputs[0]
  1343. hidden_states = self.projector(hidden_states)
  1344. if attention_mask is None:
  1345. pooled_output = hidden_states.mean(dim=1)
  1346. else:
  1347. padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  1348. hidden_states[~padding_mask] = 0.0
  1349. pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
  1350. logits = self.classifier(pooled_output)
  1351. loss = None
  1352. if labels is not None:
  1353. loss_fct = CrossEntropyLoss()
  1354. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1355. if not return_dict:
  1356. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1357. return ((loss,) + output) if loss is not None else output
  1358. return SequenceClassifierOutput(
  1359. loss=loss,
  1360. logits=logits,
  1361. hidden_states=outputs.hidden_states,
  1362. attentions=outputs.attentions,
  1363. )