modeling_tf_hubert.py 69 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672
  1. # coding=utf-8
  2. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """TensorFlow Hubert model."""
  16. from __future__ import annotations
  17. import warnings
  18. from typing import Any, Optional, Tuple, Union
  19. import numpy as np
  20. import tensorflow as tf
  21. from ...activations_tf import get_tf_activation
  22. from ...modeling_tf_outputs import TFBaseModelOutput, TFCausalLMOutput
  23. from ...modeling_tf_utils import (
  24. TFPreTrainedModel,
  25. get_initializer,
  26. keras,
  27. keras_serializable,
  28. unpack_inputs,
  29. )
  30. from ...tf_utils import shape_list, stable_softmax
  31. from ...utils import (
  32. add_start_docstrings,
  33. add_start_docstrings_to_model_forward,
  34. logging,
  35. replace_return_docstrings,
  36. )
  37. from .configuration_hubert import HubertConfig
# Module-level logger obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
# Name of the configuration class for this model. NOTE(review): its consumers are outside this view; presumably
# docstring helpers — confirm.
_CONFIG_FOR_DOC = "HubertConfig"
# Value written at masked-out positions by `_expand_mask`, which multiplies the inverted 0/1 mask by it.
LARGE_NEGATIVE = -1e8
  41. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._sample_without_replacement
  42. def _sample_without_replacement(distribution, num_samples):
  43. """
  44. Categorical sampling without replacement is currently not implemented. The gumbel-max trick will do for now - see
  45. https://github.com/tensorflow/tensorflow/issues/9260 for more info
  46. """
  47. z = -tf.math.log(tf.random.uniform(shape_list(distribution), 0, 1))
  48. _, indices = tf.nn.top_k(distribution + z, num_samples)
  49. return indices
  50. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._scatter_values_on_batch_indices
  51. def _scatter_values_on_batch_indices(values, batch_indices, output_shape):
  52. """
  53. Scatter function as in PyTorch with indices in format (batch_dim, indixes)
  54. """
  55. indices_shape = shape_list(batch_indices)
  56. # broadcast batch dim to indices_shape
  57. broad_casted_batch_dims = tf.reshape(
  58. tf.broadcast_to(tf.expand_dims(tf.range(indices_shape[0]), axis=-1), indices_shape), [1, -1]
  59. )
  60. # transform batch_indices to pair_indices
  61. pair_indices = tf.transpose(tf.concat([broad_casted_batch_dims, tf.reshape(batch_indices, [1, -1])], 0))
  62. # scatter values to pair indices
  63. return tf.scatter_nd(pair_indices, tf.reshape(values, [-1]), output_shape)
  64. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2._compute_mask_indices
  65. def _compute_mask_indices(
  66. shape: Tuple[int, int],
  67. mask_prob: float,
  68. mask_length: int,
  69. min_masks: int = 0,
  70. ) -> tf.Tensor:
  71. """
  72. Computes random mask spans for a given shape
  73. Args:
  74. shape: the shape for which to compute masks.
  75. should be of size 2 where first element is batch size and 2nd is timesteps
  76. attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
  77. mask_prob:
  78. probability for each token to be chosen as start of the span to be masked. this will be multiplied by
  79. number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
  80. however due to overlaps, the actual number will be smaller (unless no_overlap is True)
  81. mask_length: size of the mask
  82. min_masks: minimum number of masked spans
  83. Adapted from [fairseq's
  84. data_utils.py](https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376).
  85. """
  86. batch_size, sequence_length = shape
  87. if mask_length < 1:
  88. raise ValueError("`mask_length` has to be bigger than 0.")
  89. tf.debugging.assert_less(
  90. mask_length,
  91. sequence_length,
  92. message=(
  93. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and"
  94. f" `sequence_length`: {sequence_length}`"
  95. ),
  96. )
  97. # compute number of masked spans in batch
  98. num_masked_spans = mask_prob * tf.cast(sequence_length, tf.float32) / mask_length + tf.random.uniform((1,))
  99. num_masked_spans = tf.maximum(num_masked_spans, min_masks)
  100. num_masked_spans = tf.cast(num_masked_spans, tf.int32)
  101. # make sure num masked indices <= sequence_length
  102. num_masked_spans = tf.math.minimum(sequence_length // mask_length, num_masked_spans)
  103. num_masked_spans = tf.squeeze(num_masked_spans)
  104. # SpecAugment mask to fill
  105. spec_aug_mask = tf.zeros((batch_size, sequence_length), dtype=tf.int32)
  106. # uniform distribution to sample from, make sure that offset samples are < sequence_length
  107. uniform_dist = tf.ones((batch_size, sequence_length - (mask_length - 1)))
  108. # get random indices to mask
  109. spec_aug_mask_idxs = _sample_without_replacement(uniform_dist, num_masked_spans)
  110. # expand masked indices to masked spans
  111. spec_aug_mask_idxs = tf.expand_dims(spec_aug_mask_idxs, -1)
  112. spec_aug_mask_idxs = tf.tile(spec_aug_mask_idxs, (1, 1, mask_length))
  113. spec_aug_mask_idxs = tf.reshape(spec_aug_mask_idxs, (batch_size, num_masked_spans * mask_length))
  114. offsets = tf.range(mask_length)[tf.newaxis, tf.newaxis, :]
  115. offsets = tf.tile(offsets, (batch_size, num_masked_spans, 1))
  116. offsets = tf.reshape(offsets, (batch_size, num_masked_spans * mask_length))
  117. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  118. # scatter indices to mask
  119. spec_aug_mask = _scatter_values_on_batch_indices(
  120. tf.ones_like(spec_aug_mask_idxs), spec_aug_mask_idxs, tf.shape(spec_aug_mask)
  121. )
  122. return spec_aug_mask
  123. # Copied from transformers.models.bart.modeling_tf_bart._expand_mask
  124. def _expand_mask(mask: tf.Tensor, tgt_len: Optional[int] = None):
  125. """
  126. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
  127. """
  128. src_len = shape_list(mask)[1]
  129. tgt_len = tgt_len if tgt_len is not None else src_len
  130. one_cst = tf.constant(1.0)
  131. mask = tf.cast(mask, dtype=one_cst.dtype)
  132. expanded_mask = tf.tile(mask[:, None, None, :], (1, 1, tgt_len, 1))
  133. return (one_cst - expanded_mask) * LARGE_NEGATIVE
  134. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2GroupNorm with Wav2Vec2->Hubert
  135. class TFHubertGroupNorm(keras.layers.Layer):
  136. """
  137. From tensorflow-addons https://www.tensorflow.org/addons/api_docs/python/tfa/layers/GroupNormalization
  138. """
  139. def __init__(
  140. self,
  141. groups: int = 32,
  142. axis: int = -1,
  143. epsilon: float = 1e-3,
  144. center: bool = True,
  145. scale: bool = True,
  146. beta_initializer: keras.initializers.Initializer = "zeros",
  147. gamma_initializer: keras.initializers.Initializer = "ones",
  148. beta_regularizer: keras.regularizers.Regularizer = None,
  149. gamma_regularizer: keras.regularizers.Regularizer = None,
  150. beta_constraint: keras.constraints.Constraint = None,
  151. gamma_constraint: keras.constraints.Constraint = None,
  152. **kwargs,
  153. ):
  154. super().__init__(**kwargs)
  155. self.supports_masking = True
  156. self.groups = groups
  157. self.axis = axis
  158. self.epsilon = epsilon
  159. self.center = center
  160. self.scale = scale
  161. self.beta_initializer = keras.initializers.get(beta_initializer)
  162. self.gamma_initializer = keras.initializers.get(gamma_initializer)
  163. self.beta_regularizer = keras.regularizers.get(beta_regularizer)
  164. self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
  165. self.beta_constraint = keras.constraints.get(beta_constraint)
  166. self.gamma_constraint = keras.constraints.get(gamma_constraint)
  167. self._check_axis()
  168. def build(self, input_shape):
  169. self._check_if_input_shape_is_none(input_shape)
  170. self._set_number_of_groups_for_instance_norm(input_shape)
  171. self._check_size_of_dimensions(input_shape)
  172. self._create_input_spec(input_shape)
  173. self._add_gamma_weight(input_shape)
  174. self._add_beta_weight(input_shape)
  175. self.built = True
  176. super().build(input_shape)
  177. def call(self, inputs):
  178. input_shape = keras.backend.int_shape(inputs)
  179. tensor_input_shape = tf.shape(inputs)
  180. reshaped_inputs, group_shape = self._reshape_into_groups(inputs, input_shape, tensor_input_shape)
  181. normalized_inputs = self._apply_normalization(reshaped_inputs, input_shape)
  182. is_instance_norm = (input_shape[self.axis] // self.groups) == 1
  183. if not is_instance_norm:
  184. outputs = tf.reshape(normalized_inputs, tensor_input_shape)
  185. else:
  186. outputs = normalized_inputs
  187. return outputs
  188. def get_config(self):
  189. config = {
  190. "groups": self.groups,
  191. "axis": self.axis,
  192. "epsilon": self.epsilon,
  193. "center": self.center,
  194. "scale": self.scale,
  195. "beta_initializer": keras.initializers.serialize(self.beta_initializer),
  196. "gamma_initializer": keras.initializers.serialize(self.gamma_initializer),
  197. "beta_regularizer": keras.regularizers.serialize(self.beta_regularizer),
  198. "gamma_regularizer": keras.regularizers.serialize(self.gamma_regularizer),
  199. "beta_constraint": keras.constraints.serialize(self.beta_constraint),
  200. "gamma_constraint": keras.constraints.serialize(self.gamma_constraint),
  201. }
  202. base_config = super().get_config()
  203. return {**base_config, **config}
  204. def compute_output_shape(self, input_shape):
  205. return input_shape
  206. def _reshape_into_groups(self, inputs, input_shape, tensor_input_shape):
  207. group_shape = [tensor_input_shape[i] for i in range(len(input_shape))]
  208. is_instance_norm = (input_shape[self.axis] // self.groups) == 1
  209. if not is_instance_norm:
  210. group_shape[self.axis] = input_shape[self.axis] // self.groups
  211. group_shape.insert(self.axis, self.groups)
  212. group_shape = tf.stack(group_shape)
  213. reshaped_inputs = tf.reshape(inputs, group_shape)
  214. return reshaped_inputs, group_shape
  215. else:
  216. return inputs, group_shape
  217. def _apply_normalization(self, reshaped_inputs, input_shape):
  218. group_shape = keras.backend.int_shape(reshaped_inputs)
  219. group_reduction_axes = list(range(1, len(group_shape)))
  220. is_instance_norm = (input_shape[self.axis] // self.groups) == 1
  221. if not is_instance_norm:
  222. axis = -2 if self.axis == -1 else self.axis - 1
  223. else:
  224. axis = -1 if self.axis == -1 else self.axis - 1
  225. group_reduction_axes.pop(axis)
  226. mean, variance = tf.nn.moments(reshaped_inputs, group_reduction_axes, keepdims=True)
  227. gamma, beta = self._get_reshaped_weights(input_shape)
  228. normalized_inputs = tf.nn.batch_normalization(
  229. reshaped_inputs,
  230. mean=mean,
  231. variance=variance,
  232. scale=gamma,
  233. offset=beta,
  234. variance_epsilon=self.epsilon,
  235. )
  236. return normalized_inputs
  237. def _get_reshaped_weights(self, input_shape):
  238. broadcast_shape = self._create_broadcast_shape(input_shape)
  239. gamma = None
  240. beta = None
  241. if self.scale:
  242. gamma = tf.reshape(self.gamma, broadcast_shape)
  243. if self.center:
  244. beta = tf.reshape(self.beta, broadcast_shape)
  245. return gamma, beta
  246. def _check_if_input_shape_is_none(self, input_shape):
  247. dim = input_shape[self.axis]
  248. if dim is None:
  249. raise ValueError(
  250. "Axis "
  251. + str(self.axis)
  252. + " of input tensor should have a defined dimension but the layer received an input with shape "
  253. + str(input_shape)
  254. + "."
  255. )
  256. def _set_number_of_groups_for_instance_norm(self, input_shape):
  257. dim = input_shape[self.axis]
  258. if self.groups == -1:
  259. self.groups = dim
  260. def _check_size_of_dimensions(self, input_shape):
  261. dim = input_shape[self.axis]
  262. if dim < self.groups:
  263. raise ValueError(
  264. "Number of groups ("
  265. + str(self.groups)
  266. + ") cannot be more than the number of channels ("
  267. + str(dim)
  268. + ")."
  269. )
  270. if dim % self.groups != 0:
  271. raise ValueError(
  272. "Number of groups ("
  273. + str(self.groups)
  274. + ") must be a multiple of the number of channels ("
  275. + str(dim)
  276. + ")."
  277. )
  278. def _check_axis(self):
  279. if self.axis == 0:
  280. raise ValueError(
  281. "You are trying to normalize your batch axis. Do you want to use tf.layer.batch_normalization instead"
  282. )
  283. def _create_input_spec(self, input_shape):
  284. dim = input_shape[self.axis]
  285. self.input_spec = keras.layers.InputSpec(ndim=len(input_shape), axes={self.axis: dim})
  286. def _add_gamma_weight(self, input_shape):
  287. dim = input_shape[self.axis]
  288. shape = (dim,)
  289. if self.scale:
  290. self.gamma = self.add_weight(
  291. shape=shape,
  292. name="gamma",
  293. initializer=self.gamma_initializer,
  294. regularizer=self.gamma_regularizer,
  295. constraint=self.gamma_constraint,
  296. )
  297. else:
  298. self.gamma = None
  299. def _add_beta_weight(self, input_shape):
  300. dim = input_shape[self.axis]
  301. shape = (dim,)
  302. if self.center:
  303. self.beta = self.add_weight(
  304. shape=shape,
  305. name="beta",
  306. initializer=self.beta_initializer,
  307. regularizer=self.beta_regularizer,
  308. constraint=self.beta_constraint,
  309. )
  310. else:
  311. self.beta = None
  312. def _create_broadcast_shape(self, input_shape):
  313. broadcast_shape = [1] * len(input_shape)
  314. is_instance_norm = (input_shape[self.axis] // self.groups) == 1
  315. if not is_instance_norm:
  316. broadcast_shape[self.axis] = input_shape[self.axis] // self.groups
  317. broadcast_shape.insert(self.axis, self.groups)
  318. else:
  319. broadcast_shape[self.axis] = self.groups
  320. return broadcast_shape
# Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2WeightNormConv1D with Wav2Vec2->Hubert
class TFHubertWeightNormConv1D(keras.layers.Conv1D):
    """Adapted from https://www.tensorflow.org/probability/api_docs/python/tfp/layers/weight_norm/WeightNorm

    1D convolution whose kernel is re-parameterized as a direction `weight_v` and magnitudes `weight_g`. The
    effective `self.kernel` is regenerated from them at the start of every `call`. The underlying Conv1D always uses
    "valid" padding; `explicit_padding` zero frames are added on both sides of the time axis in `call` instead.
    """

    def __init__(self, filters, kernel_size, groups, explicit_padding, **kwargs):
        """
        Args:
            filters: number of output channels.
            kernel_size: convolution kernel width.
            groups: number of convolution groups, forwarded to `keras.layers.Conv1D`.
            explicit_padding: number of zero frames padded on each side of axis 1 in `call`.
            **kwargs: forwarded to `keras.layers.Conv1D`.
        """
        super().__init__(
            filters=filters,
            kernel_size=kernel_size,
            groups=groups,
            padding="valid",
            use_bias=True,
            bias_initializer="he_normal",
            **kwargs,
        )
        self.explicit_padding = explicit_padding
        # Axis of `weight_v` whose size sets the number of magnitudes stored in `weight_g` (see `build`).
        self.filter_axis = 2
        # `weight_v` is L2-normalized over axes 0 and 1, i.e. separately for each index along axis 2.
        self.kernel_norm_axes = tf.constant([0, 1])

    def _init_norm(self):
        """Set the norm of the weight vector."""
        # Initialize `weight_g` (shape (n, 1, 1)) to the L2 norm of each axis-2 slice of `weight_v`, so that
        # `_normalize_kernel` initially reproduces `weight_v` unchanged.
        kernel_norm = tf.sqrt(tf.reduce_sum(tf.square(self.weight_v), axis=self.kernel_norm_axes))
        self.weight_g.assign(kernel_norm[:, tf.newaxis, tf.newaxis])

    def _normalize_kernel(self):
        """Generate normalized weights."""
        # Unit-normalize `weight_v` per axis-2 slice and rescale by `weight_g`, transposed from (n, 1, 1) to (1, 1, n)
        # so it broadcasts along axis 2; then transpose back into the layout Conv1D reads from `self.kernel`.
        kernel = tf.nn.l2_normalize(self.weight_v, axis=self.kernel_norm_axes) * tf.transpose(self.weight_g)
        self.kernel = tf.transpose(kernel)

    def build(self, input_shape):
        if not self.built:
            super().build(input_shape)

            # Replace the Conv1D kernel with a transposed, separately named variable used as the direction.
            # NOTE(review): assuming Keras' (kernel_size, in_channels // groups, filters) kernel layout, `weight_v`
            # is (filters, in_channels // groups, kernel_size) — presumably to mirror the PyTorch weight layout for
            # checkpoint conversion; confirm.
            self.kernel = tf.Variable(tf.transpose(self.kernel), name="weight_v", trainable=True)
            self.weight_v = self.kernel

            # One magnitude per index along `filter_axis` of `weight_v`.
            self.weight_g = self.add_weight(
                name="weight_g",
                shape=(int(self.weight_v.shape[self.filter_axis]), 1, 1),
                initializer="ones",
                dtype=self.weight_v.dtype,
                trainable=True,
            )
            self._init_norm()
            # Rebind `self.bias` to a fresh zero-initialized weight named "bias".
            # NOTE(review): `super().build` presumably already created a bias (use_bias=True, "he_normal"
            # initializer from `__init__`); confirm that weight is not also used or trained.
            self.bias = self.add_weight(name="bias", shape=(self.filters,), initializer="zeros", trainable=True)

    def call(self, inputs):
        # TODO Matt: Assigning to attributes in call() is deeply sinful in TensorFlow, as it should be idempotent.
        #            This whole layer should be replaced by a layer that doesn't inherit from Conv1D, but instead calls
        #            a functional 1d convolution with normalized weights that it generates (but does not store!)
        self._normalize_kernel()

        # Pad only axis 1 (time) by `explicit_padding` frames on each side, since the conv itself is "valid".
        padded_inputs = tf.pad(inputs, ((0, 0), (self.explicit_padding, self.explicit_padding), (0, 0)))
        output = super().call(padded_inputs)

        return output
  367. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2NoLayerNormConvLayer with Wav2Vec2->Hubert
  368. class TFHubertNoLayerNormConvLayer(keras.layers.Layer):
  369. def __init__(self, config: HubertConfig, layer_id: int = 0, **kwargs: Any) -> None:
  370. super().__init__(**kwargs)
  371. self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
  372. self.out_conv_dim = config.conv_dim[layer_id]
  373. self.conv = keras.layers.Conv1D(
  374. filters=self.out_conv_dim,
  375. kernel_size=config.conv_kernel[layer_id],
  376. strides=config.conv_stride[layer_id],
  377. use_bias=config.conv_bias,
  378. name="conv",
  379. )
  380. self.activation = get_tf_activation(config.feat_extract_activation)
  381. def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
  382. hidden_states = self.conv(hidden_states)
  383. hidden_states = self.activation(hidden_states)
  384. return hidden_states
  385. def build(self, input_shape=None):
  386. if self.built:
  387. return
  388. self.built = True
  389. if getattr(self, "conv", None) is not None:
  390. with tf.name_scope(self.conv.name):
  391. self.conv.build([None, None, self.in_conv_dim])
  392. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2LayerNormConvLayer with Wav2Vec2->Hubert
  393. class TFHubertLayerNormConvLayer(keras.layers.Layer):
  394. def __init__(self, config: HubertConfig, layer_id: int = 0, **kwargs: Any) -> None:
  395. super().__init__(**kwargs)
  396. self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
  397. self.out_conv_dim = config.conv_dim[layer_id]
  398. self.conv = keras.layers.Conv1D(
  399. filters=self.out_conv_dim,
  400. kernel_size=config.conv_kernel[layer_id],
  401. strides=config.conv_stride[layer_id],
  402. use_bias=config.conv_bias,
  403. name="conv",
  404. )
  405. self.layer_norm = keras.layers.LayerNormalization(name="layer_norm", epsilon=config.layer_norm_eps)
  406. self.activation = get_tf_activation(config.feat_extract_activation)
  407. def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
  408. hidden_states = self.conv(hidden_states)
  409. hidden_states = self.layer_norm(hidden_states)
  410. hidden_states = self.activation(hidden_states)
  411. return hidden_states
  412. def build(self, input_shape=None):
  413. if self.built:
  414. return
  415. self.built = True
  416. if getattr(self, "conv", None) is not None:
  417. with tf.name_scope(self.conv.name):
  418. self.conv.build([None, None, self.in_conv_dim])
  419. if getattr(self, "layer_norm", None) is not None:
  420. with tf.name_scope(self.layer_norm.name):
  421. self.layer_norm.build([None, None, self.out_conv_dim])
  422. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2GroupNormConvLayer with Wav2Vec2->Hubert
  423. class TFHubertGroupNormConvLayer(keras.layers.Layer):
  424. def __init__(self, config: HubertConfig, layer_id: int = 0, **kwargs: Any) -> None:
  425. super().__init__(**kwargs)
  426. self.in_conv_dim = config.conv_dim[layer_id] if layer_id > 0 else 1
  427. self.out_conv_dim = config.conv_dim[layer_id]
  428. self.conv = keras.layers.Conv1D(
  429. filters=self.out_conv_dim,
  430. kernel_size=config.conv_kernel[layer_id],
  431. strides=config.conv_stride[layer_id],
  432. use_bias=config.conv_bias,
  433. name="conv",
  434. )
  435. self.activation = get_tf_activation(config.feat_extract_activation)
  436. self.layer_norm = TFHubertGroupNorm(groups=self.out_conv_dim, epsilon=config.layer_norm_eps, name="layer_norm")
  437. def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
  438. hidden_states = self.conv(hidden_states)
  439. hidden_states = self.layer_norm(hidden_states)
  440. hidden_states = self.activation(hidden_states)
  441. return hidden_states
  442. def build(self, input_shape=None):
  443. if self.built:
  444. return
  445. self.built = True
  446. if getattr(self, "conv", None) is not None:
  447. with tf.name_scope(self.conv.name):
  448. self.conv.build([None, None, self.in_conv_dim])
  449. if getattr(self, "layer_norm", None) is not None:
  450. with tf.name_scope(self.layer_norm.name):
  451. self.layer_norm.build([None, None, self.out_conv_dim])
  452. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2PositionalConvEmbedding with Wav2Vec2->Hubert
  453. class TFHubertPositionalConvEmbedding(keras.layers.Layer):
  454. def __init__(self, config: HubertConfig, **kwargs: Any) -> None:
  455. super().__init__(**kwargs)
  456. self.conv = TFHubertWeightNormConv1D(
  457. filters=config.hidden_size,
  458. kernel_size=config.num_conv_pos_embeddings,
  459. groups=config.num_conv_pos_embedding_groups,
  460. explicit_padding=config.num_conv_pos_embeddings // 2,
  461. name="conv",
  462. )
  463. self.padding = TFHubertSamePadLayer(config.num_conv_pos_embeddings)
  464. self.activation = get_tf_activation(config.feat_extract_activation)
  465. self.config = config
  466. def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
  467. hidden_states = self.conv(hidden_states)
  468. hidden_states = self.padding(hidden_states)
  469. hidden_states = self.activation(hidden_states)
  470. return hidden_states
  471. def build(self, input_shape=None):
  472. if self.built:
  473. return
  474. self.built = True
  475. if getattr(self, "conv", None) is not None:
  476. with tf.name_scope(self.conv.name):
  477. self.conv.build([None, None, self.config.hidden_size])
  478. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2SamePadLayer with Wav2Vec2->Hubert
  479. class TFHubertSamePadLayer(keras.layers.Layer):
  480. def __init__(self, num_conv_pos_embeddings, **kwargs):
  481. super().__init__(**kwargs)
  482. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  483. def call(self, hidden_states):
  484. if self.num_pad_remove > 0:
  485. hidden_states = hidden_states[:, : -self.num_pad_remove, :]
  486. return hidden_states
  487. class TFHubertFeatureEncoder(keras.layers.Layer):
  488. def __init__(self, config: HubertConfig, **kwargs: Any) -> None:
  489. super().__init__(**kwargs)
  490. if config.feat_extract_norm == "group":
  491. conv_layers = [TFHubertGroupNormConvLayer(config, layer_id=0, name=f"conv_layers.{0}")] + [
  492. TFHubertNoLayerNormConvLayer(config, layer_id=i + 1, name=f"conv_layers.{i+1}")
  493. for i in range(config.num_feat_extract_layers - 1)
  494. ]
  495. elif config.feat_extract_norm == "layer":
  496. conv_layers = [
  497. TFHubertLayerNormConvLayer(config, layer_id=i, name=f"conv_layers.{i}")
  498. for i in range(config.num_feat_extract_layers)
  499. ]
  500. else:
  501. raise ValueError(
  502. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  503. )
  504. self.conv_layers = conv_layers
  505. def call(self, input_values):
  506. hidden_states = tf.expand_dims(input_values, -1)
  507. for conv_layer in self.conv_layers:
  508. hidden_states = conv_layer(hidden_states)
  509. return hidden_states
  510. def build(self, input_shape=None):
  511. if self.built:
  512. return
  513. self.built = True
  514. for conv_layer in self.conv_layers:
  515. with tf.name_scope(conv_layer.name):
  516. conv_layer.build(None)
  517. class TFHubertFeatureExtractor(TFHubertFeatureEncoder):
  518. def __init__(self, config, **kwargs):
  519. super().__init__(config, **kwargs)
  520. warnings.warn(
  521. f"The class `{self.__class__.__name__}` has been depreciated "
  522. "and will be removed in Transformers v5. "
  523. f"Use `{self.__class__.__bases__[0].__name__}` instead.",
  524. FutureWarning,
  525. )
  526. class TFHubertFeatureProjection(keras.layers.Layer):
  527. def __init__(self, config: HubertConfig, **kwargs):
  528. super().__init__(**kwargs)
  529. self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
  530. self.projection = keras.layers.Dense(
  531. units=config.hidden_size,
  532. kernel_initializer=get_initializer(config.initializer_range),
  533. bias_initializer="zeros",
  534. name="projection",
  535. )
  536. self.dropout = keras.layers.Dropout(rate=config.feat_proj_dropout)
  537. self.config = config
  538. def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
  539. hidden_states = self.layer_norm(hidden_states)
  540. hidden_states = self.projection(hidden_states)
  541. hidden_states = self.dropout(hidden_states, training=training)
  542. return hidden_states
  543. def build(self, input_shape=None):
  544. if self.built:
  545. return
  546. self.built = True
  547. if getattr(self, "layer_norm", None) is not None:
  548. with tf.name_scope(self.layer_norm.name):
  549. self.layer_norm.build([None, None, self.config.conv_dim[-1]])
  550. if getattr(self, "projection", None) is not None:
  551. with tf.name_scope(self.projection.name):
  552. self.projection.build([None, None, self.config.conv_dim[-1]])
  553. # Copied from transformers.models.bart.modeling_tf_bart.TFBartAttention with TFBart->TFHubert
  554. class TFHubertAttention(keras.layers.Layer):
  555. """Multi-headed attention from "Attention Is All You Need"""
  556. def __init__(
  557. self,
  558. embed_dim: int,
  559. num_heads: int,
  560. dropout: float = 0.0,
  561. is_decoder: bool = False,
  562. bias: bool = True,
  563. **kwargs,
  564. ):
  565. super().__init__(**kwargs)
  566. self.embed_dim = embed_dim
  567. self.num_heads = num_heads
  568. self.dropout = keras.layers.Dropout(dropout)
  569. self.head_dim = embed_dim // num_heads
  570. if (self.head_dim * num_heads) != self.embed_dim:
  571. raise ValueError(
  572. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  573. f" and `num_heads`: {num_heads})."
  574. )
  575. self.scaling = self.head_dim**-0.5
  576. self.is_decoder = is_decoder
  577. self.k_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="k_proj")
  578. self.q_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="q_proj")
  579. self.v_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="v_proj")
  580. self.out_proj = keras.layers.Dense(embed_dim, use_bias=bias, name="out_proj")
  581. def _shape(self, tensor: tf.Tensor, seq_len: int, bsz: int):
  582. return tf.transpose(tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)), (0, 2, 1, 3))
  583. def call(
  584. self,
  585. hidden_states: tf.Tensor,
  586. key_value_states: tf.Tensor | None = None,
  587. past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
  588. attention_mask: tf.Tensor | None = None,
  589. layer_head_mask: tf.Tensor | None = None,
  590. training: Optional[bool] = False,
  591. ) -> Tuple[tf.Tensor, tf.Tensor | None]:
  592. """Input shape: Batch x Time x Channel"""
  593. # if key_value_states are provided this layer is used as a cross-attention layer
  594. # for the decoder
  595. is_cross_attention = key_value_states is not None
  596. bsz, tgt_len, embed_dim = shape_list(hidden_states)
  597. # get query proj
  598. query_states = self.q_proj(hidden_states) * self.scaling
  599. # get key, value proj
  600. if is_cross_attention and past_key_value is not None:
  601. # reuse k,v, cross_attentions
  602. key_states = past_key_value[0]
  603. value_states = past_key_value[1]
  604. elif is_cross_attention:
  605. # cross_attentions
  606. key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
  607. value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
  608. elif past_key_value is not None:
  609. # reuse k, v, self_attention
  610. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  611. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  612. key_states = tf.concat([past_key_value[0], key_states], axis=2)
  613. value_states = tf.concat([past_key_value[1], value_states], axis=2)
  614. else:
  615. # self_attention
  616. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  617. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  618. if self.is_decoder:
  619. # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states.
  620. # Further calls to cross_attention layer can then reuse all cross-attention
  621. # key/value_states (first "if" case)
  622. # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of
  623. # all previous decoder key/value_states. Further calls to uni-directional self-attention
  624. # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
  625. # if encoder bi-directional self-attention `past_key_value` is always `None`
  626. past_key_value = (key_states, value_states)
  627. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  628. query_states = tf.reshape(self._shape(query_states, tgt_len, bsz), proj_shape)
  629. key_states = tf.reshape(key_states, proj_shape)
  630. value_states = tf.reshape(value_states, proj_shape)
  631. src_len = shape_list(key_states)[1]
  632. attn_weights = tf.matmul(query_states, key_states, transpose_b=True)
  633. tf.debugging.assert_equal(
  634. shape_list(attn_weights),
  635. [bsz * self.num_heads, tgt_len, src_len],
  636. message=(
  637. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  638. f" {shape_list(attn_weights)}"
  639. ),
  640. )
  641. if attention_mask is not None:
  642. tf.debugging.assert_equal(
  643. shape_list(attention_mask),
  644. [bsz, 1, tgt_len, src_len],
  645. message=(
  646. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
  647. f" {shape_list(attention_mask)}"
  648. ),
  649. )
  650. attention_mask = tf.cast(attention_mask, dtype=attn_weights.dtype)
  651. attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attention_mask
  652. attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
  653. attn_weights = stable_softmax(attn_weights, axis=-1)
  654. if layer_head_mask is not None:
  655. tf.debugging.assert_equal(
  656. shape_list(layer_head_mask),
  657. [self.num_heads],
  658. message=(
  659. f"Head mask for a single layer should be of size {(self.num_heads)}, but is"
  660. f" {shape_list(layer_head_mask)}"
  661. ),
  662. )
  663. attn_weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * tf.reshape(
  664. attn_weights, (bsz, self.num_heads, tgt_len, src_len)
  665. )
  666. attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
  667. attn_probs = self.dropout(attn_weights, training=training)
  668. attn_output = tf.matmul(attn_probs, value_states)
  669. tf.debugging.assert_equal(
  670. shape_list(attn_output),
  671. [bsz * self.num_heads, tgt_len, self.head_dim],
  672. message=(
  673. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  674. f" {shape_list(attn_output)}"
  675. ),
  676. )
  677. attn_output = tf.transpose(
  678. tf.reshape(attn_output, (bsz, self.num_heads, tgt_len, self.head_dim)), (0, 2, 1, 3)
  679. )
  680. attn_output = tf.reshape(attn_output, (bsz, tgt_len, embed_dim))
  681. attn_output = self.out_proj(attn_output)
  682. attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
  683. return attn_output, attn_weights, past_key_value
  684. def build(self, input_shape=None):
  685. if self.built:
  686. return
  687. self.built = True
  688. if getattr(self, "k_proj", None) is not None:
  689. with tf.name_scope(self.k_proj.name):
  690. self.k_proj.build([None, None, self.embed_dim])
  691. if getattr(self, "q_proj", None) is not None:
  692. with tf.name_scope(self.q_proj.name):
  693. self.q_proj.build([None, None, self.embed_dim])
  694. if getattr(self, "v_proj", None) is not None:
  695. with tf.name_scope(self.v_proj.name):
  696. self.v_proj.build([None, None, self.embed_dim])
  697. if getattr(self, "out_proj", None) is not None:
  698. with tf.name_scope(self.out_proj.name):
  699. self.out_proj.build([None, None, self.embed_dim])
  700. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2FeedForward with Wav2Vec2->Hubert
  701. class TFHubertFeedForward(keras.layers.Layer):
  702. def __init__(self, config: HubertConfig, **kwargs):
  703. super().__init__(**kwargs)
  704. self.intermediate_dropout = keras.layers.Dropout(config.activation_dropout)
  705. self.intermediate_dense = keras.layers.Dense(
  706. units=config.intermediate_size,
  707. kernel_initializer=get_initializer(config.initializer_range),
  708. bias_initializer="zeros",
  709. name="intermediate_dense",
  710. )
  711. self.intermediate_act_fn = get_tf_activation(config.hidden_act)
  712. self.output_dense = keras.layers.Dense(
  713. units=config.hidden_size,
  714. kernel_initializer=get_initializer(config.initializer_range),
  715. bias_initializer="zeros",
  716. name="output_dense",
  717. )
  718. self.output_dropout = keras.layers.Dropout(config.hidden_dropout)
  719. self.config = config
  720. def call(self, hidden_states: tf.Tensor, training: bool = False) -> tf.Tensor:
  721. hidden_states = self.intermediate_dense(hidden_states)
  722. hidden_states = self.intermediate_act_fn(hidden_states)
  723. hidden_states = self.intermediate_dropout(hidden_states, training=training)
  724. hidden_states = self.output_dense(hidden_states)
  725. hidden_states = self.output_dropout(hidden_states, training=training)
  726. return hidden_states
  727. def build(self, input_shape=None):
  728. if self.built:
  729. return
  730. self.built = True
  731. if getattr(self, "intermediate_dense", None) is not None:
  732. with tf.name_scope(self.intermediate_dense.name):
  733. self.intermediate_dense.build([None, None, self.config.hidden_size])
  734. if getattr(self, "output_dense", None) is not None:
  735. with tf.name_scope(self.output_dense.name):
  736. self.output_dense.build([None, None, self.config.intermediate_size])
  737. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderLayer with Wav2Vec2->Hubert
  738. class TFHubertEncoderLayer(keras.layers.Layer):
  739. def __init__(self, config: HubertConfig, **kwargs):
  740. super().__init__(**kwargs)
  741. self.attention = TFHubertAttention(
  742. embed_dim=config.hidden_size,
  743. num_heads=config.num_attention_heads,
  744. dropout=config.attention_dropout,
  745. is_decoder=False,
  746. name="attention",
  747. )
  748. self.dropout = keras.layers.Dropout(config.hidden_dropout)
  749. self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
  750. self.feed_forward = TFHubertFeedForward(config, name="feed_forward")
  751. self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")
  752. self.config = config
  753. def call(
  754. self,
  755. hidden_states: tf.Tensor,
  756. attention_mask: tf.Tensor | None = None,
  757. output_attentions: Optional[bool] = False,
  758. training: bool = False,
  759. ) -> Tuple[tf.Tensor]:
  760. attn_residual = hidden_states
  761. hidden_states, attn_weights, _ = self.attention(
  762. hidden_states, attention_mask=attention_mask, training=training
  763. )
  764. hidden_states = self.dropout(hidden_states, training=training)
  765. hidden_states = attn_residual + hidden_states
  766. hidden_states = self.layer_norm(hidden_states)
  767. hidden_states = hidden_states + self.feed_forward(hidden_states)
  768. hidden_states = self.final_layer_norm(hidden_states)
  769. outputs = (hidden_states,)
  770. if output_attentions:
  771. outputs += (attn_weights,)
  772. return outputs
  773. def build(self, input_shape=None):
  774. if self.built:
  775. return
  776. self.built = True
  777. if getattr(self, "attention", None) is not None:
  778. with tf.name_scope(self.attention.name):
  779. self.attention.build(None)
  780. if getattr(self, "layer_norm", None) is not None:
  781. with tf.name_scope(self.layer_norm.name):
  782. self.layer_norm.build([None, None, self.config.hidden_size])
  783. if getattr(self, "feed_forward", None) is not None:
  784. with tf.name_scope(self.feed_forward.name):
  785. self.feed_forward.build(None)
  786. if getattr(self, "final_layer_norm", None) is not None:
  787. with tf.name_scope(self.final_layer_norm.name):
  788. self.final_layer_norm.build([None, None, self.config.hidden_size])
  789. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderLayerStableLayerNorm with Wav2Vec2->Hubert
  790. class TFHubertEncoderLayerStableLayerNorm(keras.layers.Layer):
  791. def __init__(self, config: HubertConfig, **kwargs):
  792. super().__init__(**kwargs)
  793. self.attention = TFHubertAttention(
  794. embed_dim=config.hidden_size,
  795. num_heads=config.num_attention_heads,
  796. dropout=config.attention_dropout,
  797. is_decoder=False,
  798. name="attention",
  799. )
  800. self.dropout = keras.layers.Dropout(config.hidden_dropout)
  801. self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
  802. self.feed_forward = TFHubertFeedForward(config, name="feed_forward")
  803. self.final_layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="final_layer_norm")
  804. self.config = config
  805. def call(
  806. self,
  807. hidden_states: tf.Tensor,
  808. attention_mask: tf.Tensor | None = None,
  809. output_attentions: Optional[bool] = False,
  810. training: bool = False,
  811. ) -> Tuple[tf.Tensor]:
  812. attn_residual = hidden_states
  813. hidden_states = self.layer_norm(hidden_states)
  814. hidden_states, attn_weights, _ = self.attention(
  815. hidden_states, attention_mask=attention_mask, training=training
  816. )
  817. hidden_states = self.dropout(hidden_states, training=training)
  818. hidden_states = attn_residual + hidden_states
  819. hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
  820. outputs = (hidden_states,)
  821. if output_attentions:
  822. outputs += (attn_weights,)
  823. return outputs
  824. def build(self, input_shape=None):
  825. if self.built:
  826. return
  827. self.built = True
  828. if getattr(self, "attention", None) is not None:
  829. with tf.name_scope(self.attention.name):
  830. self.attention.build(None)
  831. if getattr(self, "layer_norm", None) is not None:
  832. with tf.name_scope(self.layer_norm.name):
  833. self.layer_norm.build([None, None, self.config.hidden_size])
  834. if getattr(self, "feed_forward", None) is not None:
  835. with tf.name_scope(self.feed_forward.name):
  836. self.feed_forward.build(None)
  837. if getattr(self, "final_layer_norm", None) is not None:
  838. with tf.name_scope(self.final_layer_norm.name):
  839. self.final_layer_norm.build([None, None, self.config.hidden_size])
  840. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2Encoder with Wav2Vec2->Hubert
  841. class TFHubertEncoder(keras.layers.Layer):
  842. def __init__(self, config: HubertConfig, **kwargs):
  843. super().__init__(**kwargs)
  844. self.config = config
  845. self.pos_conv_embed = TFHubertPositionalConvEmbedding(config, name="pos_conv_embed")
  846. self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
  847. self.dropout = keras.layers.Dropout(config.hidden_dropout)
  848. self.layer = [TFHubertEncoderLayer(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)]
  849. def call(
  850. self,
  851. hidden_states: tf.Tensor,
  852. attention_mask: tf.Tensor | None = None,
  853. output_attentions: Optional[bool] = False,
  854. output_hidden_states: Optional[bool] = False,
  855. return_dict: Optional[bool] = True,
  856. training: Optional[bool] = False,
  857. ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
  858. all_hidden_states = () if output_hidden_states else None
  859. all_self_attentions = () if output_attentions else None
  860. if attention_mask is not None:
  861. hidden_states = hidden_states * tf.expand_dims(attention_mask, -1)
  862. attention_mask = _expand_mask(attention_mask)
  863. else:
  864. attention_mask = None
  865. position_embeddings = self.pos_conv_embed(hidden_states)
  866. hidden_states = hidden_states + position_embeddings
  867. hidden_states = self.layer_norm(hidden_states)
  868. hidden_states = self.dropout(hidden_states, training=training)
  869. for i, layer_module in enumerate(self.layer):
  870. if output_hidden_states:
  871. all_hidden_states = all_hidden_states + (hidden_states,)
  872. # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
  873. dropout_probability = np.random.uniform(0, 1)
  874. if training and (dropout_probability < self.config.layerdrop): # skip the layer
  875. continue
  876. layer_outputs = layer_module(
  877. hidden_states=hidden_states,
  878. attention_mask=attention_mask,
  879. output_attentions=output_attentions,
  880. training=training,
  881. )
  882. hidden_states = layer_outputs[0]
  883. if output_attentions:
  884. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  885. # Add last layer
  886. if output_hidden_states:
  887. all_hidden_states = all_hidden_states + (hidden_states,)
  888. if not return_dict:
  889. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  890. return TFBaseModelOutput(
  891. last_hidden_state=hidden_states,
  892. hidden_states=all_hidden_states,
  893. attentions=all_self_attentions,
  894. )
  895. def build(self, input_shape=None):
  896. if self.built:
  897. return
  898. self.built = True
  899. if getattr(self, "pos_conv_embed", None) is not None:
  900. with tf.name_scope(self.pos_conv_embed.name):
  901. self.pos_conv_embed.build(None)
  902. if getattr(self, "layer_norm", None) is not None:
  903. with tf.name_scope(self.layer_norm.name):
  904. self.layer_norm.build([None, None, self.config.hidden_size])
  905. if getattr(self, "layer", None) is not None:
  906. for layer in self.layer:
  907. with tf.name_scope(layer.name):
  908. layer.build(None)
  909. # Copied from transformers.models.wav2vec2.modeling_tf_wav2vec2.TFWav2Vec2EncoderStableLayerNorm with Wav2Vec2->Hubert
  910. class TFHubertEncoderStableLayerNorm(keras.layers.Layer):
  911. def __init__(self, config: HubertConfig, **kwargs):
  912. super().__init__(**kwargs)
  913. self.config = config
  914. self.pos_conv_embed = TFHubertPositionalConvEmbedding(config, name="pos_conv_embed")
  915. self.layer_norm = keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
  916. self.dropout = keras.layers.Dropout(config.hidden_dropout)
  917. self.layer = [
  918. TFHubertEncoderLayerStableLayerNorm(config, name=f"layers.{i}") for i in range(config.num_hidden_layers)
  919. ]
  920. def call(
  921. self,
  922. hidden_states: tf.Tensor,
  923. attention_mask: tf.Tensor | None = None,
  924. output_attentions: Optional[bool] = False,
  925. output_hidden_states: Optional[bool] = False,
  926. return_dict: Optional[bool] = True,
  927. training: Optional[bool] = False,
  928. ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
  929. all_hidden_states = () if output_hidden_states else None
  930. all_self_attentions = () if output_attentions else None
  931. if attention_mask is not None:
  932. hidden_states = hidden_states * tf.expand_dims(attention_mask, -1)
  933. attention_mask = _expand_mask(attention_mask)
  934. else:
  935. attention_mask = None
  936. position_embeddings = self.pos_conv_embed(hidden_states)
  937. hidden_states = hidden_states + position_embeddings
  938. hidden_states = self.dropout(hidden_states, training=training)
  939. for i, layer_module in enumerate(self.layer):
  940. if output_hidden_states:
  941. all_hidden_states = all_hidden_states + (hidden_states,)
  942. # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
  943. dropout_probability = np.random.uniform(0, 1)
  944. if training and (dropout_probability < self.config.layerdrop): # skip the layer
  945. continue
  946. layer_outputs = layer_module(
  947. hidden_states=hidden_states,
  948. attention_mask=attention_mask,
  949. output_attentions=output_attentions,
  950. training=training,
  951. )
  952. hidden_states = layer_outputs[0]
  953. if output_attentions:
  954. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  955. hidden_states = self.layer_norm(hidden_states)
  956. if output_hidden_states:
  957. all_hidden_states = all_hidden_states + (hidden_states,)
  958. if not return_dict:
  959. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  960. return TFBaseModelOutput(
  961. last_hidden_state=hidden_states,
  962. hidden_states=all_hidden_states,
  963. attentions=all_self_attentions,
  964. )
  965. def build(self, input_shape=None):
  966. if self.built:
  967. return
  968. self.built = True
  969. if getattr(self, "pos_conv_embed", None) is not None:
  970. with tf.name_scope(self.pos_conv_embed.name):
  971. self.pos_conv_embed.build(None)
  972. if getattr(self, "layer_norm", None) is not None:
  973. with tf.name_scope(self.layer_norm.name):
  974. self.layer_norm.build([None, None, self.config.hidden_size])
  975. if getattr(self, "layer", None) is not None:
  976. for layer in self.layer:
  977. with tf.name_scope(layer.name):
  978. layer.build(None)
  979. @keras_serializable
  980. class TFHubertMainLayer(keras.layers.Layer):
  981. config_class = HubertConfig
  982. def __init__(self, config: HubertConfig, **kwargs):
  983. super().__init__(**kwargs)
  984. self.config = config
  985. self.feature_extractor = TFHubertFeatureEncoder(config, name="feature_extractor")
  986. self.feature_projection = TFHubertFeatureProjection(config, name="feature_projection")
  987. if config.do_stable_layer_norm:
  988. self.encoder = TFHubertEncoderStableLayerNorm(config, name="encoder")
  989. else:
  990. self.encoder = TFHubertEncoder(config, name="encoder")
  991. def build(self, input_shape=None):
  992. self.masked_spec_embed = self.add_weight(
  993. shape=(self.config.hidden_size,), initializer="uniform", trainable=True, name="masked_spec_embed"
  994. )
  995. if self.built:
  996. return
  997. self.built = True
  998. if getattr(self, "feature_extractor", None) is not None:
  999. with tf.name_scope(self.feature_extractor.name):
  1000. self.feature_extractor.build(None)
  1001. if getattr(self, "feature_projection", None) is not None:
  1002. with tf.name_scope(self.feature_projection.name):
  1003. self.feature_projection.build(None)
  1004. if getattr(self, "encoder", None) is not None:
  1005. with tf.name_scope(self.encoder.name):
  1006. self.encoder.build(None)
  1007. def _get_feat_extract_output_lengths(self, input_lengths: tf.Tensor):
  1008. """
  1009. Computes the output length of the convolutional layers
  1010. """
  1011. def _conv_out_length(input_length, kernel_size, stride):
  1012. # 1D convolutional layer output length formula taken
  1013. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  1014. return (input_length - kernel_size) // stride + 1
  1015. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  1016. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  1017. return input_lengths
  1018. def _mask_hidden_states(self, hidden_states: tf.Tensor, mask_time_indices: tf.Tensor | None = None):
  1019. """
  1020. Masks extracted features along time axis and/or along feature axis according to
  1021. [SpecAugment](https://arxiv.org/abs/1904.08779).
  1022. """
  1023. batch_size, sequence_length, hidden_size = shape_list(hidden_states)
  1024. # `config.apply_spec_augment` can set masking to False
  1025. if not getattr(self.config, "apply_spec_augment", True):
  1026. return hidden_states
  1027. if mask_time_indices is not None:
  1028. # apply SpecAugment along time axis with given mask_time_indices
  1029. hidden_states = tf.where(
  1030. tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
  1031. self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
  1032. hidden_states,
  1033. )
  1034. elif self.config.mask_time_prob > 0:
  1035. # generate indices & apply SpecAugment along time axis
  1036. mask_time_indices = _compute_mask_indices(
  1037. (batch_size, sequence_length),
  1038. mask_prob=self.config.mask_time_prob,
  1039. mask_length=self.config.mask_time_length,
  1040. min_masks=2,
  1041. )
  1042. hidden_states = tf.where(
  1043. tf.cast(mask_time_indices[:, :, tf.newaxis], tf.bool),
  1044. self.masked_spec_embed[tf.newaxis, tf.newaxis, :],
  1045. hidden_states,
  1046. )
  1047. # apply SpecAugment along feature axis
  1048. if self.config.mask_feature_prob > 0:
  1049. mask_feature_indices = _compute_mask_indices(
  1050. (batch_size, hidden_size),
  1051. mask_prob=self.config.mask_feature_prob,
  1052. mask_length=self.config.mask_feature_length,
  1053. )
  1054. hidden_states = tf.where(mask_feature_indices[:, tf.newaxis, :], hidden_states, 0)
  1055. return hidden_states
  1056. @unpack_inputs
  1057. def call(
  1058. self,
  1059. input_values: tf.Tensor,
  1060. attention_mask: tf.Tensor | None = None,
  1061. token_type_ids: tf.Tensor | None = None,
  1062. position_ids: tf.Tensor | None = None,
  1063. head_mask: tf.Tensor | None = None,
  1064. inputs_embeds: tf.Tensor | None = None,
  1065. output_attentions: tf.Tensor | None = None,
  1066. output_hidden_states: tf.Tensor | None = None,
  1067. return_dict: Optional[bool] = None,
  1068. training: bool = False,
  1069. **kwargs: Any,
  1070. ):
  1071. hidden_states = self.feature_extractor(tf.cast(input_values, tf.float32), training=training)
  1072. if attention_mask is not None:
  1073. # compute real output lengths according to convolution formula
  1074. output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, -1))
  1075. attention_mask = tf.sequence_mask(
  1076. output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype
  1077. )
  1078. hidden_states = self.feature_projection(hidden_states, training=training)
  1079. mask_time_indices = kwargs.get("mask_time_indices", None)
  1080. if training:
  1081. hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
  1082. encoder_outputs = self.encoder(
  1083. hidden_states,
  1084. attention_mask=attention_mask,
  1085. output_attentions=output_attentions,
  1086. output_hidden_states=output_hidden_states,
  1087. return_dict=return_dict,
  1088. training=training,
  1089. )
  1090. hidden_states = encoder_outputs[0]
  1091. if not return_dict:
  1092. return (hidden_states,) + encoder_outputs[1:]
  1093. return TFBaseModelOutput(
  1094. last_hidden_state=hidden_states,
  1095. hidden_states=encoder_outputs.hidden_states,
  1096. attentions=encoder_outputs.attentions,
  1097. )
  1098. class TFHubertPreTrainedModel(TFPreTrainedModel):
  1099. """
  1100. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  1101. models.
  1102. """
  1103. config_class = HubertConfig
  1104. base_model_prefix = "hubert"
  1105. main_input_name = "input_values"
  1106. @property
  1107. def input_signature(self):
  1108. return {
  1109. "input_values": tf.TensorSpec((None, 16000), tf.float32, name="input_values"),
  1110. "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
  1111. "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
  1112. }
  1113. def __init__(self, config, *inputs, **kwargs):
  1114. super().__init__(config, *inputs, **kwargs)
  1115. logger.warning(
  1116. f"\n{self.__class__.__name__} has backpropagation operations that are NOT supported on CPU. If you wish "
  1117. "to train/fine-tune this model, you need a GPU or a TPU"
  1118. )
  1119. HUBERT_START_DOCSTRING = r"""
  1120. This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the
  1121. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  1122. etc.)
  1123. This model is also a [keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it
  1124. as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and
  1125. behavior.
  1126. <Tip>
  1127. TensorFlow models and layers in `transformers` accept two formats as input:
  1128. - having all inputs as keyword arguments (like PyTorch models), or
  1129. - having all inputs as a list, tuple or dict in the first positional argument.
  1130. The reason the second format is supported is that Keras methods prefer this format when passing inputs to models
  1131. and layers. Because of this support, when using methods like `model.fit()` things should "just work" for you - just
  1132. pass your inputs and labels in any format that `model.fit()` supports! If, however, you want to use the second
  1133. format outside of Keras methods like `fit()` and `predict()`, such as when creating your own layers or models with
  1134. the Keras `Functional` API, there are three possibilities you can use to gather all the input Tensors in the first
  1135. positional argument:
  1136. - a single Tensor with `input_values` only and nothing else: `model(input_values)`
  1137. - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
  1138. `model([input_values, attention_mask])` or `model([input_values, attention_mask, token_type_ids])`
  1139. - a dictionary with one or several input Tensors associated to the input names given in the docstring:
  1140. `model({"input_values": input_values, "token_type_ids": token_type_ids})`
  1141. Note that when creating models and layers with
  1142. [subclassing](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) then you don't need to worry
  1143. about any of this, as you can just pass inputs like you would to any other Python function!
  1144. </Tip>
  1145. Args:
  1146. config ([`HubertConfig`]): Model configuration class with all the parameters of the model.
  1147. Initializing with a config file does not load the weights associated with the model, only the
  1148. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  1149. """
  1150. HUBERT_INPUTS_DOCSTRING = r"""
  1151. Args:
  1152. input_values (`np.ndarray`, `tf.Tensor`, `List[tf.Tensor]` `Dict[str, tf.Tensor]` or `Dict[str, np.ndarray]` and each example must have the shape `({0})`):
  1153. Indices of input sequence tokens in the vocabulary.
  1154. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
  1155. [`PreTrainedTokenizer.encode`] for details.
  1156. [What are input IDs?](../glossary#input-ids)
  1157. attention_mask (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
  1158. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1159. - 1 for tokens that are **not masked**,
  1160. - 0 for tokens that are **masked**.
  1161. [What are attention masks?](../glossary#attention-mask)
  1162. token_type_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
  1163. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1164. 1]`:
  1165. - 0 corresponds to a *sentence A* token,
  1166. - 1 corresponds to a *sentence B* token.
  1167. [What are token type IDs?](../glossary#token-type-ids)
  1168. position_ids (`np.ndarray` or `tf.Tensor` of shape `({0})`, *optional*):
  1169. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1170. config.max_position_embeddings - 1]`.
  1171. [What are position IDs?](../glossary#position-ids)
  1172. head_mask (`np.ndarray` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1173. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  1174. - 1 indicates the head is **not masked**,
  1175. - 0 indicates the head is **masked**.
  1176. inputs_embeds (`np.ndarray` or `tf.Tensor` of shape `({0}, hidden_size)`, *optional*):
  1177. Optionally, instead of passing `input_values` you can choose to directly pass an embedded representation.
  1178. This is useful if you want more control over how to convert `input_values` indices into associated vectors
  1179. than the model's internal embedding lookup matrix.
  1180. output_attentions (`bool`, *optional*):
  1181. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  1182. tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the
  1183. config will be used instead.
  1184. output_hidden_states (`bool`, *optional*):
  1185. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  1186. more detail. This argument can be used only in eager mode, in graph mode the value in the config will be
  1187. used instead.
  1188. return_dict (`bool`, *optional*):
  1189. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in
  1190. eager mode, in graph mode the value will always be set to True.
  1191. training (`bool`, *optional*, defaults to `False``):
  1192. Whether or not to use the model in training mode (some modules like dropout modules have different
  1193. behaviors between training and evaluation).
  1194. """
  @add_start_docstrings(
      "The bare TFHubert Model transformer outputting raw hidden-states without any specific head on top.",
      HUBERT_START_DOCSTRING,
  )
  class TFHubertModel(TFHubertPreTrainedModel):
      # Headless HuBERT model: forwards every input to the `TFHubertMainLayer` stored as `self.hubert` and returns
      # its outputs unchanged. The public class docstring is assembled by `add_start_docstrings`, so maintainer
      # notes are kept in comments instead of a class docstring (which would be appended to the generated one).

      def __init__(self, config: HubertConfig, *inputs, **kwargs):
          """Initialise the model and create the `hubert` main layer from `config`."""
          super().__init__(config, *inputs, **kwargs)
          self.config = config
          self.hubert = TFHubertMainLayer(config, name="hubert")

      @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
      @replace_return_docstrings(output_type=TFBaseModelOutput, config_class=_CONFIG_FOR_DOC)
      @unpack_inputs
      def call(
          self,
          input_values: tf.Tensor,
          attention_mask: tf.Tensor | None = None,
          token_type_ids: tf.Tensor | None = None,
          position_ids: tf.Tensor | None = None,
          head_mask: tf.Tensor | None = None,
          inputs_embeds: tf.Tensor | None = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          training: bool = False,
      ) -> Union[TFBaseModelOutput, Tuple[tf.Tensor]]:
          """

          Returns:

          Example:

          ```python
          >>> from transformers import AutoProcessor, TFHubertModel
          >>> from datasets import load_dataset
          >>> import soundfile as sf

          >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
          >>> model = TFHubertModel.from_pretrained("facebook/hubert-large-ls960-ft")


          >>> def map_to_array(batch):
          ...     speech, _ = sf.read(batch["file"])
          ...     batch["speech"] = speech
          ...     return batch


          >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
          >>> ds = ds.map(map_to_array)

          >>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values  # Batch size 1
          >>> hidden_states = model(input_values).last_hidden_state
          ```"""
          # Fall back to the config only when a flag was left unset (`None`). A truthiness test
          # (`flag if flag else config.flag`) discards an explicit `False`: e.g. `return_dict=False` would still
          # produce a `TFBaseModelOutput` whenever `config.return_dict` is True.
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if return_dict is None:
              return_dict = self.config.return_dict

          return self.hubert(
              input_values=input_values,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              training=training,
          )

      def build(self, input_shape=None):
          """Build the `hubert` main layer once, inside a `tf.name_scope` named after that layer."""
          if self.built:
              return
          self.built = True
          if getattr(self, "hubert", None) is not None:
              with tf.name_scope(self.hubert.name):
                  self.hubert.build(None)
  @add_start_docstrings(
      """TFHubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""",
      HUBERT_START_DOCSTRING,
  )
  class TFHubertForCTC(TFHubertPreTrainedModel):
      # HuBERT encoder -> dropout -> dense `lm_head` producing `config.vocab_size` logits per frame. When `labels`
      # are passed, a CTC loss is computed with `config.pad_token_id` as the blank index and reduced according to
      # `config.ctc_loss_reduction`. Comments are used instead of a class docstring because
      # `add_start_docstrings` assembles the public `__doc__`.

      def __init__(self, config: HubertConfig, *inputs, **kwargs):
          """Create the encoder, the dropout applied before the head, and the vocabulary projection."""
          super().__init__(config, *inputs, **kwargs)
          self.hubert = TFHubertMainLayer(config, name="hubert")
          self.dropout = keras.layers.Dropout(config.final_dropout)
          self.lm_head = keras.layers.Dense(config.vocab_size, name="lm_head")
          # Input width of `lm_head`, used in `build` to create its kernel. When the config enables an adapter the
          # width is `config.output_hidden_size` (presumably the adapter's output width -- confirm in the main layer).
          self.output_hidden_size = (
              config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
          )

      def freeze_feature_extractor(self):
          """
          Calling this function will disable the gradient computation for the feature encoder so that its parameters will
          not be updated during training.
          """
          warnings.warn(
              "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
              "Please use the equivalent `freeze_feature_encoder` method instead.",
              FutureWarning,
          )
          self.freeze_feature_encoder()

      def freeze_feature_encoder(self):
          """
          Calling this function will disable the gradient computation for the feature encoder so that its parameters
          will not be updated during training.
          """
          self.hubert.feature_extractor.trainable = False

      @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
      @replace_return_docstrings(output_type=TFCausalLMOutput, config_class=_CONFIG_FOR_DOC)
      @unpack_inputs
      def call(
          self,
          input_values: tf.Tensor,
          attention_mask: tf.Tensor | None = None,
          token_type_ids: tf.Tensor | None = None,
          position_ids: tf.Tensor | None = None,
          head_mask: tf.Tensor | None = None,
          inputs_embeds: tf.Tensor | None = None,
          output_attentions: Optional[bool] = None,
          labels: tf.Tensor | None = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          training: Optional[bool] = False,
      ) -> Union[TFCausalLMOutput, Tuple[tf.Tensor]]:
          r"""
          labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the connectionist temporal classification (CTC) loss. Valid label ids are in
              `[0, ..., config.vocab_size - 1]`. Pad the end of each label sequence with `-100` (any negative value):
              only non-negative labels are counted toward the target length of each sequence.

          Returns:

          Example:

          ```python
          >>> import tensorflow as tf
          >>> from transformers import AutoProcessor, TFHubertForCTC
          >>> from datasets import load_dataset
          >>> import soundfile as sf

          >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
          >>> model = TFHubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")


          >>> def map_to_array(batch):
          ...     speech, _ = sf.read(batch["file"])
          ...     batch["speech"] = speech
          ...     return batch


          >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
          >>> ds = ds.map(map_to_array)

          >>> input_values = processor(ds["speech"][0], return_tensors="tf").input_values  # Batch size 1
          >>> logits = model(input_values).logits
          >>> predicted_ids = tf.argmax(logits, axis=-1)

          >>> transcription = processor.decode(predicted_ids[0])

          >>> # compute loss
          >>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"

          >>> # Pass the target transcription as text to encode labels
          >>> labels = processor(text=target_transcription, return_tensors="tf").input_ids

          >>> loss = model(input_values, labels=labels).loss
          ```"""
          # Label ids index the `vocab_size` logits, so the largest valid id is `vocab_size - 1`.
          if labels is not None and tf.reduce_max(labels) >= self.config.vocab_size:
              raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

          outputs = self.hubert(
              input_values=input_values,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              training=training,
          )
          hidden_states = outputs[0]
          hidden_states = self.dropout(hidden_states, training=training)

          logits = self.lm_head(hidden_states)

          if labels is not None:
              # Without an explicit mask every input sample counts toward the sequence length.
              attention_mask = (
                  attention_mask if attention_mask is not None else tf.ones_like(input_values, dtype=tf.float32)
              )
              # Convert raw-input lengths into the number of frames produced by the feature extractor.
              input_lengths = self.hubert._get_feat_extract_output_lengths(tf.reduce_sum(attention_mask, axis=-1))

              # assuming that padded tokens are filled with -100
              # when not being attended to
              labels_mask = tf.cast(labels >= 0, tf.int32)
              target_lengths = tf.reduce_sum(labels_mask, axis=-1)

              loss = tf.nn.ctc_loss(
                  logits=logits,
                  labels=labels,
                  logit_length=input_lengths,
                  label_length=target_lengths,
                  blank_index=self.config.pad_token_id,
                  logits_time_major=False,
              )

              # Any other reduction value leaves the per-example losses unreduced.
              if self.config.ctc_loss_reduction == "sum":
                  loss = tf.reduce_sum(loss)
                  loss = tf.reshape(loss, (1,))
              elif self.config.ctc_loss_reduction == "mean":
                  loss = tf.reduce_mean(loss)
                  loss = tf.reshape(loss, (1,))
          else:
              loss = None

          if not return_dict:
              output = (logits,) + outputs[1:]
              return ((loss,) + output) if loss is not None else output

          return TFCausalLMOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )

      def build(self, input_shape=None):
          """Build the encoder and the `lm_head` projection once, each inside its own `tf.name_scope`."""
          if self.built:
              return
          self.built = True
          if getattr(self, "hubert", None) is not None:
              with tf.name_scope(self.hubert.name):
                  self.hubert.build(None)
          if getattr(self, "lm_head", None) is not None:
              with tf.name_scope(self.lm_head.name):
                  self.lm_head.build([None, None, self.output_hidden_size])