# coding=utf-8
# Copyright 2024 Descript and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformers DAC model."""

import math
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from .configuration_dac import DacConfig


# General docstring
_CONFIG_FOR_DOC = "DacConfig"


@dataclass
class DacOutput(ModelOutput):
    """
    Args:
        loss (`torch.Tensor`):
            Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
        audio_values (`torch.Tensor` of shape `(batch_size, input_length)`):
            Reconstructed audio data.
        quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
            Quantized continuous representation of input.
        audio_codes (`torch.LongTensor` of shape `(batch_size, num_codebooks, time_steps)`):
            Codebook indices for each codebook (quantized discrete representation of input).
        projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
            Projected latents (continuous representation of input before quantization).
    """

    loss: torch.FloatTensor = None
    audio_values: torch.FloatTensor = None
    quantized_representation: torch.FloatTensor = None
    audio_codes: torch.LongTensor = None
    projected_latents: torch.FloatTensor = None


@dataclass
class DacEncoderOutput(ModelOutput):
    """
    Args:
        loss (`torch.Tensor`):
            Loss from the encoder model, comprising the weighted combination of the commitment and codebook losses.
        quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
            Quantized continuous representation of input.
        audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
            Codebook indices for each codebook (quantized discrete representation of input).
        projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`, *optional*):
            Projected latents (continuous representation of input before quantization).
    """

    loss: torch.FloatTensor = None
    quantized_representation: torch.FloatTensor = None
    audio_codes: torch.FloatTensor = None
    projected_latents: torch.FloatTensor = None


@dataclass
# Copied from transformers.models.encodec.modeling_encodec.EncodecDecoderOutput with Encodec->Dac, segment_length->input_length
class DacDecoderOutput(ModelOutput):
    """
    Args:
        audio_values (`torch.FloatTensor` of shape `(batch_size, input_length)`, *optional*):
            Decoded audio values, obtained using the decoder part of Dac.
    """

    audio_values: torch.FloatTensor = None


class Snake1d(nn.Module):
    """
    A 1-dimensional Snake activation function module.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))

    def forward(self, hidden_states):
        shape = hidden_states.shape
        hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
        hidden_states = hidden_states + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * hidden_states).pow(2)
        hidden_states = hidden_states.reshape(shape)
        return hidden_states
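

# Note on the Snake activation above: it computes x + (1 / alpha) * sin^2(alpha * x), a periodic
# activation commonly used in neural audio models for its periodic inductive bias. `alpha` is a
# learned per-channel parameter and the 1e-9 term only guards against division by zero. A minimal
# sketch of the same computation on a toy tensor (names below are illustrative, not part of the
# modeling code):
#
#     alpha = torch.ones(1, 4, 1)                       # one alpha per channel
#     x = torch.randn(2, 4, 16)                         # (batch, channels, time)
#     y = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)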


class DacVectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo (https://github.com/karpathy/deep-vector-quantization)

    Additionally uses following tricks from improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, config: DacConfig):
        super().__init__()
        self.in_proj = nn.Conv1d(config.hidden_size, config.codebook_dim, kernel_size=1)
        self.out_proj = nn.Conv1d(config.codebook_dim, config.hidden_size, kernel_size=1)
        self.codebook = nn.Embedding(config.codebook_size, config.codebook_dim)
        # Kept so `DacResidualVectorQuantize.from_latents` can read the per-quantizer codebook dim.
        self.codebook_dim = config.codebook_dim

    def forward(self, hidden_state):
        """
        Quantizes the input tensor using a fixed codebook and returns the corresponding codebook vectors.

        Args:
            hidden_state (`torch.FloatTensor` of shape `(batch_size, dimension, time_steps)`):
                Input tensor.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized continuous representation of input.
            commitment_loss (`torch.FloatTensor` of shape `(1)`):
                Commitment loss to train encoder to predict vectors closer to codebook entries.
            codebook_loss (`torch.FloatTensor` of shape `(1)`):
                Codebook loss to update the codebook.
            audio_codes (`torch.LongTensor` of shape `(batch_size, time_steps)`):
                Codebook indices for each codebook, quantized discrete representation of input.
            projected_latents (`torch.FloatTensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
                Projected latents (continuous representation of input before quantization).
        """
        projected_latents = self.in_proj(hidden_state)
        quantized_representation, audio_codes = self.decode_latents(projected_latents)

        commitment_loss = F.mse_loss(projected_latents, quantized_representation.detach(), reduction="mean")
        codebook_loss = F.mse_loss(quantized_representation, projected_latents.detach(), reduction="mean")

        # noop in forward pass, straight-through gradient estimator in backward pass
        quantized_representation = projected_latents + (quantized_representation - projected_latents).detach()
        quantized_representation = self.out_proj(quantized_representation)

        return quantized_representation, commitment_loss, codebook_loss, audio_codes, projected_latents

    def decode_latents(self, hidden_states):
        batch_size, hidden_dim, sequence_length = hidden_states.shape
        encodings = hidden_states.permute(0, 2, 1).reshape(batch_size * sequence_length, hidden_dim)
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute negative squared euclidean distance with codebook, so `max` picks the nearest code
        l2_norm = encodings.pow(2).sum(1, keepdim=True)
        dist = -(l2_norm - 2 * encodings @ codebook.t()) - codebook.pow(2).sum(1, keepdim=True).t()
        indices = dist.max(1)[1]
        indices = indices.reshape(hidden_states.size(0), -1)
        quantized_representation = self.codebook(indices).transpose(1, 2)
        return quantized_representation, indices
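

# Note on `decode_latents` above: both the latents and the codebook rows are L2-normalized, so the
# euclidean nearest-neighbor lookup is equivalent to picking, for every time step, the codebook
# entry with the highest cosine similarity. A minimal sketch of the same lookup on toy tensors
# (names are illustrative only):
#
#     encodings = F.normalize(torch.randn(6, 8))      # (batch * time, codebook_dim)
#     codebook = F.normalize(torch.randn(1024, 8))    # (codebook_size, codebook_dim)
#     indices = (encodings @ codebook.t()).argmax(1)  # nearest code per vector, by cosine similarity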


class DacResidualUnit(nn.Module):
    """
    A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
    """

    def __init__(self, dimension: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2

        self.snake1 = Snake1d(dimension)
        self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
        self.snake2 = Snake1d(dimension)
        self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)

    def forward(self, hidden_state):
        """
        Forward pass through the residual unit.

        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
                Input tensor.

        Returns:
            output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
                Input tensor after passing through the residual unit.
        """
        output_tensor = hidden_state
        output_tensor = self.conv1(self.snake1(output_tensor))
        output_tensor = self.conv2(self.snake2(output_tensor))

        padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
        if padding > 0:
            hidden_state = hidden_state[..., padding:-padding]
        output_tensor = hidden_state + output_tensor
        return output_tensor
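

# Note on the residual connection above: if the dilated convolution shortens the sequence, the skip
# branch is center-cropped by `padding` samples on each side before being added back. With the
# kernel size of 7 used here the computed padding keeps the length unchanged, so the crop is
# normally a no-op and only guards against configurations where the output ends up shorter.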


class DacEncoderBlock(nn.Module):
    """Encoder block used in DAC encoder."""

    def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
        super().__init__()

        dimension = config.encoder_hidden_size * 2**stride_index
        self.res_unit1 = DacResidualUnit(dimension // 2, dilation=1)
        self.res_unit2 = DacResidualUnit(dimension // 2, dilation=3)
        self.res_unit3 = DacResidualUnit(dimension // 2, dilation=9)
        self.snake1 = Snake1d(dimension // 2)
        self.conv1 = nn.Conv1d(
            dimension // 2, dimension, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
        )

    def forward(self, hidden_state):
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        hidden_state = self.snake1(self.res_unit3(hidden_state))
        hidden_state = self.conv1(hidden_state)

        return hidden_state


class DacDecoderBlock(nn.Module):
    """Decoder block used in DAC decoder."""

    def __init__(self, config: DacConfig, stride: int = 1, stride_index: int = 1):
        super().__init__()

        input_dim = config.decoder_hidden_size // 2**stride_index
        output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
        self.snake1 = Snake1d(input_dim)
        self.conv_t1 = nn.ConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
        )
        self.res_unit1 = DacResidualUnit(output_dim, dilation=1)
        self.res_unit2 = DacResidualUnit(output_dim, dilation=3)
        self.res_unit3 = DacResidualUnit(output_dim, dilation=9)

    def forward(self, hidden_state):
        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv_t1(hidden_state)
        hidden_state = self.res_unit1(hidden_state)
        hidden_state = self.res_unit2(hidden_state)
        hidden_state = self.res_unit3(hidden_state)
        return hidden_state


class DacResidualVectorQuantize(nn.Module):
    """
    ResidualVectorQuantize block - Introduced in SoundStream: An end2end neural audio codec (https://arxiv.org/abs/2107.03312)
    """

    def __init__(self, config: DacConfig):
        super().__init__()

        n_codebooks = config.n_codebooks
        quantizer_dropout = config.quantizer_dropout

        self.n_codebooks = n_codebooks

        self.quantizers = nn.ModuleList([DacVectorQuantize(config) for i in range(config.n_codebooks)])
        self.quantizer_dropout = quantizer_dropout

    def forward(self, hidden_state, n_quantizers: int = None):
        """
        Quantizes the input tensor using a fixed set of codebooks and returns corresponding codebook vectors.

        Args:
            hidden_state (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Input tensor to be quantized.
            n_quantizers (`int`, *optional*):
                Number of quantizers to use. If specified and `self.quantizer_dropout` is True,
                this argument is ignored during training, and a random number of quantizers is used.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized continuous representation of input.
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
                Codebook indices for each codebook (quantized discrete representation of input).
            projected_latents (`torch.Tensor` of shape `(batch_size, num_codebooks * dimension, time_steps)`):
                Projected latents (continuous representation of input before quantization).
            commitment_loss (`torch.Tensor` of shape `(1)`):
                Commitment loss to train the encoder to predict vectors closer to codebook entries.
            codebook_loss (`torch.Tensor` of shape `(1)`):
                Codebook loss to update the codebook.
        """
        quantized_representation = 0
        residual = hidden_state

        commitment_loss = 0
        codebook_loss = 0

        audio_codes = []
        projected_latents = []

        n_quantizers = n_quantizers if n_quantizers is not None else self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((hidden_state.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (hidden_state.shape[0],))
            n_dropout = int(hidden_state.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(hidden_state.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            quantized_representation_i, commitment_loss_i, codebook_loss_i, indices_i, projected_latents_i = quantizer(
                residual
            )

            # Create mask to apply quantizer dropout
            mask = torch.full((hidden_state.shape[0],), fill_value=i, device=hidden_state.device) < n_quantizers
            quantized_representation = quantized_representation + quantized_representation_i * mask[:, None, None]
            residual = residual - quantized_representation_i

            # Sum losses
            commitment_loss += commitment_loss_i * mask
            codebook_loss += codebook_loss_i * mask

            audio_codes.append(indices_i)
            projected_latents.append(projected_latents_i)

        audio_codes = torch.stack(audio_codes, dim=1)
        projected_latents = torch.cat(projected_latents, dim=1)

        return quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss
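
    # Note on the residual quantization loop in `forward` above: each quantizer only encodes the
    # residual left over by the previous ones, so later codebooks add progressively finer detail and
    # reconstruction quality improves with the number of codebooks kept. During training,
    # `quantizer_dropout` randomly truncates the number of active quantizers per example (via
    # `mask`), which is what allows a single model to be decoded at several bitrates at inference
    # time by passing a smaller `n_quantizers`.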

    def from_codes(self, audio_codes: torch.Tensor):
        """
        Reconstructs the continuous representation from quantized codes.

        Args:
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`):
                Quantized discrete representation of input.

        Returns:
            quantized_representation (`torch.Tensor`):
                Quantized continuous representation of input.
            projected_latents (`torch.Tensor`):
                List of projected latents (continuous representations of input before quantization)
                for each codebook.
            audio_codes (`torch.Tensor`):
                Codebook indices for each codebook.
        """
        quantized_representation = 0.0
        projected_latents = []
        n_codebooks = audio_codes.shape[1]
        for i in range(n_codebooks):
            projected_latents_i = self.quantizers[i].codebook(audio_codes[:, i, :]).transpose(1, 2)
            projected_latents.append(projected_latents_i)
            quantized_representation += self.quantizers[i].out_proj(projected_latents_i)
        return quantized_representation, torch.cat(projected_latents, dim=1), audio_codes
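
    # Note: `from_codes` is the lookup path used by `DacModel.decode` when `audio_codes` is passed.
    # Each index tensor is embedded with its codebook, projected back to `hidden_size` via
    # `out_proj`, and the per-codebook contributions are summed into one continuous representation.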

    def from_latents(self, latents: torch.Tensor):
        """Reconstructs the quantized representation from unquantized latents.

        Args:
            latents (`torch.Tensor` of shape `(batch_size, total_latent_dimension, time_steps)`):
                Continuous representation of input after projection.

        Returns:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized representation of the full-projected space.
            quantized_latents (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`):
                Quantized representation of the latent space (continuous representation before quantization).
        """
        quantized_representation = 0
        quantized_latents = []
        codes = []
        codebook_dims_tensor = torch.tensor([0] + [q.codebook_dim for q in self.quantizers])
        dims = torch.cumsum(codebook_dims_tensor, dim=0)

        n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
        for i in range(n_codebooks):
            hidden_dim_j, hidden_dim_k = dims[i], dims[i + 1]
            quantized_latents_i, codes_i = self.quantizers[i].decode_latents(latents[:, hidden_dim_j:hidden_dim_k, :])
            quantized_latents.append(quantized_latents_i)
            codes.append(codes_i)

            quantized_representation_i = self.quantizers[i].out_proj(quantized_latents_i)
            quantized_representation = quantized_representation + quantized_representation_i

        return quantized_representation, torch.cat(quantized_latents, dim=1)


class DacDecoder(nn.Module):
    """DAC Decoder"""

    def __init__(self, config: DacConfig):
        super().__init__()

        input_channel = config.hidden_size
        channels = config.decoder_hidden_size
        strides = config.upsampling_ratios

        # Add first conv layer
        self.conv1 = nn.Conv1d(input_channel, channels, kernel_size=7, padding=3)

        # Add upsampling + MRF blocks
        block = []
        for stride_index, stride in enumerate(strides):
            block += [DacDecoderBlock(config, stride, stride_index)]
        self.block = nn.ModuleList(block)

        output_dim = config.decoder_hidden_size // 2 ** (stride_index + 1)
        self.snake1 = Snake1d(output_dim)
        self.conv2 = nn.Conv1d(output_dim, 1, kernel_size=7, padding=3)
        self.tanh = nn.Tanh()

    def forward(self, hidden_state):
        hidden_state = self.conv1(hidden_state)

        for layer in self.block:
            hidden_state = layer(hidden_state)

        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv2(hidden_state)
        hidden_state = self.tanh(hidden_state)

        return hidden_state


class DacEncoder(nn.Module):
    """DAC Encoder"""

    def __init__(self, config: DacConfig):
        super().__init__()

        strides = config.downsampling_ratios
        # Create first convolution
        self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)

        self.block = []
        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride_index, stride in enumerate(strides):
            stride_index = stride_index + 1
            self.block += [DacEncoderBlock(config, stride=stride, stride_index=stride_index)]

        self.block = nn.ModuleList(self.block)
        d_model = config.encoder_hidden_size * 2**stride_index
        self.snake1 = Snake1d(d_model)
        self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)

    def forward(self, hidden_state):
        hidden_state = self.conv1(hidden_state)

        for module in self.block:
            hidden_state = module(hidden_state)

        hidden_state = self.snake1(hidden_state)
        hidden_state = self.conv2(hidden_state)

        return hidden_state
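

# Note on the two stacks above: each `DacEncoderBlock` downsamples time by its `stride` and doubles
# the channel count, so the encoder's overall hop length is the product of
# `config.downsampling_ratios`; `DacDecoder` mirrors this with transposed convolutions driven by
# `config.upsampling_ratios`. Roughly, with `hop = prod(downsampling_ratios)`:
#
#     # (batch, 1, hop * T) waveform  --encoder--> (batch, hidden_size, T) latent frames
#     # (batch, hidden_size, T) latent frames --decoder--> (batch, 1, hop * T) waveform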


class DacPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = DacConfig
    base_model_prefix = "dac"
    main_input_name = "input_values"

    def _init_weights(self, module):
        if isinstance(module, nn.Conv1d):
            nn.init.trunc_normal_(module.weight, std=0.02)
            nn.init.constant_(module.bias, 0)

    def apply_weight_norm(self):
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        for layer in self.quantizer.quantizers:
            weight_norm(layer.in_proj)
            weight_norm(layer.out_proj)

        weight_norm(self.encoder.conv1)
        weight_norm(self.encoder.conv2)

        for layer in self.encoder.block:
            weight_norm(layer.conv1)
            weight_norm(layer.res_unit1.conv1)
            weight_norm(layer.res_unit1.conv2)
            weight_norm(layer.res_unit2.conv1)
            weight_norm(layer.res_unit2.conv2)
            weight_norm(layer.res_unit3.conv1)
            weight_norm(layer.res_unit3.conv2)

        weight_norm(self.decoder.conv1)
        weight_norm(self.decoder.conv2)

        for layer in self.decoder.block:
            weight_norm(layer.conv_t1)
            weight_norm(layer.res_unit1.conv1)
            weight_norm(layer.res_unit1.conv2)
            weight_norm(layer.res_unit2.conv1)
            weight_norm(layer.res_unit2.conv2)
            weight_norm(layer.res_unit3.conv1)
            weight_norm(layer.res_unit3.conv2)

    def remove_weight_norm(self):
        for layer in self.quantizer.quantizers:
            nn.utils.remove_weight_norm(layer.in_proj)
            nn.utils.remove_weight_norm(layer.out_proj)

        nn.utils.remove_weight_norm(self.encoder.conv1)
        nn.utils.remove_weight_norm(self.encoder.conv2)

        for layer in self.encoder.block:
            nn.utils.remove_weight_norm(layer.conv1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv2)
            nn.utils.remove_weight_norm(layer.res_unit2.conv1)
            nn.utils.remove_weight_norm(layer.res_unit2.conv2)
            nn.utils.remove_weight_norm(layer.res_unit3.conv1)
            nn.utils.remove_weight_norm(layer.res_unit3.conv2)

        nn.utils.remove_weight_norm(self.decoder.conv1)
        nn.utils.remove_weight_norm(self.decoder.conv2)

        for layer in self.decoder.block:
            nn.utils.remove_weight_norm(layer.conv_t1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv1)
            nn.utils.remove_weight_norm(layer.res_unit1.conv2)
            nn.utils.remove_weight_norm(layer.res_unit2.conv1)
            nn.utils.remove_weight_norm(layer.res_unit2.conv2)
            nn.utils.remove_weight_norm(layer.res_unit3.conv1)
            nn.utils.remove_weight_norm(layer.res_unit3.conv2)
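

# Note on the two helpers above: they wrap (and unwrap) a weight-norm parametrization around every
# convolution in the quantizer projections, the encoder, and the decoder, mirroring the
# weight-normalized convolutions used in the original Descript implementation. Removing weight norm
# folds the normalization back into plain conv weights, which can be convenient for inference or
# export.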


DAC_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`DacConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


DAC_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
            Audio data to encode.
        n_quantizers (`int`, *optional*):
            Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The DAC (Descript Audio Codec) model.",
    DAC_START_DOCSTRING,
)
class DacModel(DacPreTrainedModel):
    def __init__(self, config: DacConfig):
        super().__init__(config)
        self.config = config

        self.encoder = DacEncoder(config)
        self.decoder = DacDecoder(config)

        self.quantizer = DacResidualVectorQuantize(config)

        self.bits_per_codebook = int(math.log2(self.config.codebook_size))
        if 2**self.bits_per_codebook != self.config.codebook_size:
            raise ValueError("The codebook_size must be a power of 2.")

        # Initialize weights and apply final processing
        self.post_init()
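
    # Note: `bits_per_codebook` above is log2(codebook_size); for example, a codebook_size of 1024
    # gives 10 bits per codebook per latent frame. The overall bitrate is therefore roughly
    # n_codebooks * bits_per_codebook * frame_rate bits per second, where frame_rate is the
    # sampling rate divided by the product of the downsampling ratios.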

    @replace_return_docstrings(output_type=DacEncoderOutput, config_class=_CONFIG_FOR_DOC)
    def encode(
        self,
        input_values: torch.Tensor,
        n_quantizers: int = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Encode given audio data and return quantized latent codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, 1, time_steps)`):
                Input audio data to encode.
            n_quantizers (`int`, *optional*):
                Number of quantizers to use. If `None`, all quantizers are used. Default is `None`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        quantized_representation = self.encoder(input_values)
        quantized_representation, audio_codes, projected_latents, commitment_loss, codebook_loss = self.quantizer(
            quantized_representation, n_quantizers
        )

        loss = self.config.commitment_loss_weight * commitment_loss + self.config.codebook_loss_weight * codebook_loss

        if not return_dict:
            return (loss, quantized_representation, audio_codes, projected_latents)

        return DacEncoderOutput(loss, quantized_representation, audio_codes, projected_latents)

    @replace_return_docstrings(output_type=DacDecoderOutput, config_class=_CONFIG_FOR_DOC)
    def decode(
        self,
        quantized_representation: Optional[torch.Tensor] = None,
        audio_codes: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Decode given latent codes and return audio data.

        Args:
            quantized_representation (`torch.Tensor` of shape `(batch_size, dimension, time_steps)`, *optional*):
                Quantized continuous representation of input.
            audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, time_steps)`, *optional*):
                The codebook indices for each codebook, representing the quantized discrete
                representation of the input. This parameter should be provided if you want
                to decode directly from the audio codes (it will overwrite quantized_representation).
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:

        """
        if quantized_representation is None and audio_codes is None:
            raise ValueError("Either `quantized_representation` or `audio_codes` must be provided.")

        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if audio_codes is not None:
            quantized_representation = self.quantizer.from_codes(audio_codes)[0]

        audio_values = self.decoder(quantized_representation).squeeze(1)

        if not return_dict:
            return (audio_values,)

        return DacDecoderOutput(audio_values)

    @add_start_docstrings_to_model_forward(DAC_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=DacOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_values: torch.Tensor,
        n_quantizers: int = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Returns:

        Examples:

        ```python
        >>> from datasets import load_dataset, Audio
        >>> from transformers import DacModel, AutoProcessor

        >>> librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> model = DacModel.from_pretrained("descript/dac_16khz")
        >>> processor = AutoProcessor.from_pretrained("descript/dac_16khz")

        >>> librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
        >>> audio_sample = librispeech_dummy[-1]["audio"]["array"]

        >>> inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")

        >>> encoder_outputs = model.encode(inputs["input_values"])

        >>> # Get the intermediate audio codes
        >>> audio_codes = encoder_outputs.audio_codes

        >>> # Reconstruct the audio from its quantized representation
        >>> audio_values = model.decode(encoder_outputs.quantized_representation)

        >>> # or the equivalent with a forward pass
        >>> audio_values = model(inputs["input_values"]).audio_values
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        length = input_values.shape[-1]
        loss, quantized_representation, audio_codes, projected_latents = self.encode(
            input_values, n_quantizers, return_dict=False
        )
        audio_values = self.decode(quantized_representation, return_dict=False)[0][..., :length]

        if not return_dict:
            return (loss, audio_values, quantized_representation, audio_codes, projected_latents)

        return DacOutput(loss, audio_values, quantized_representation, audio_codes, projected_latents)