# coding=utf-8
# Copyright 2023 The Suno AI Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BARK model."""

import math
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from ...generation import GenerationMixin
from ...generation.logits_process import (
    AlternatingCodebooksLogitsProcessor,
    BarkEosPrioritizerLogitsProcessor,
    SuppressTokensLogitsProcessor,
)
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
from ...modeling_utils import PreTrainedModel, get_parameter_device
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_accelerate_available,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
)
from ..auto import AutoModel
from .configuration_bark import (
    BarkCoarseConfig,
    BarkConfig,
    BarkFineConfig,
    BarkSemanticConfig,
    BarkSubModelConfig,
)
from .generation_configuration_bark import (
    BarkCoarseGenerationConfig,
    BarkFineGenerationConfig,
    BarkSemanticGenerationConfig,
)


if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)


_CHECKPOINT_FOR_DOC = "suno/bark-small"
_CONFIG_FOR_DOC = "BarkConfig"

class BarkSelfAttention(nn.Module):
    # adapted from GPTNeoSelfAttention and Bark code
    # BarkSelfAttention can have two attention types, i.e. full attention or causal attention

    def __init__(self, config, is_causal=False):
        super().__init__()

        # regularization
        self.dropout = config.dropout
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.embed_dim // self.num_heads

        if config.hidden_size % config.num_heads != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        # key, query, value projections for all heads, but in a batch
        self.att_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=config.bias)
        # output projection
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.bias)

        self.is_causal = is_causal
        if is_causal:
            block_size = config.block_size
            bias = torch.tril(torch.ones((block_size, block_size), dtype=bool)).view(1, 1, block_size, block_size)
            self.register_buffer("bias", bias)

    # Copied from transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoSelfAttention._split_heads
    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        # re-assemble all head outputs side by side
        # (batch, num_heads, seq_len, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
        tensor = tensor.transpose(1, 2).contiguous()
        tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))

        return tensor

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # unlike GPTNeo's SelfAttention, divide by the square root of the dimension of the query and the key
        attn_weights = torch.matmul(query, key.transpose(-1, -2)) * (1.0 / math.sqrt(self.head_dim))

        if self.is_causal:
            query_length, key_length = query.size(-2), key.size(-2)

            # mask out future positions (where the causal bias is 0) with the minimum representable value
            attn_weights = attn_weights.masked_fill(
                self.bias[:, :, key_length - query_length : key_length, :key_length] == 0,
                torch.finfo(attn_weights.dtype).min,
            )

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        # (batch, num_heads, seq_len, seq_len) x (batch, num_heads, seq_len, attn_head_size)
        # -> (batch, num_heads, seq_len, attn_head_size)
        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        past_key_values=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if past_key_values is not None:
            past_key = past_key_values[0]
            past_value = past_key_values[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs
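
# Illustrative sketch (editor addition, not part of the original file): the head split/merge round-trip used above.
# Assuming a Bark sub-model config with hidden_size=768 and num_heads=12 (hypothetical values), `_split_heads`
# reshapes (batch, seq_len, 768) -> (batch, 12, seq_len, 64) and `_merge_heads` undoes it:
#
#   x = torch.randn(2, 5, 768)                                        # (batch, seq_len, hidden_size)
#   attn = BarkSelfAttention(config, is_causal=True)                  # `config`: hypothetical sub-model config
#   heads = attn._split_heads(x, attn.num_heads, attn.head_dim)       # (2, 12, 5, 64)
#   back = attn._merge_heads(heads, attn.num_heads, attn.head_dim)    # (2, 5, 768)
#   assert back.shape == x.shape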

class BarkSelfFlashAttention2(BarkSelfAttention):
    """
    Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stay
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API
    of flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment,
        # which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference:
        # https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a
        # wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim - (batch, seq_length, head, head_features)
        return tensor

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        # re-assemble all head outputs side by side
        # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
        tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
        return tensor

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        past_key_values=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        batch_size, query_len, _ = hidden_states.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if past_key_values is not None:
            # (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features)
            past_key = past_key_values[0].transpose(1, 2)
            past_value = past_key_values[1].transpose(1, 2)
            # and merge on seq_length
            key = torch.cat((past_key, key), dim=1)
            value = torch.cat((past_value, value), dim=1)

        if use_cache is True:
            # (batch, head, seq_length, head_features)
            present = (key.transpose(1, 2), value.transpose(1, 2))
        else:
            present = None

        attn_output = _flash_attention_forward(
            query,
            key,
            value,
            attention_mask,
            query_len,
            dropout=self.dropout if self.training else 0.0,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
            is_causal=self.is_causal,
        )

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            attn_weights = None
            outputs += (attn_weights,)

        return outputs

BARK_ATTENTION_CLASSES = {
    "eager": BarkSelfAttention,
    "flash_attention_2": BarkSelfFlashAttention2,
}
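
# Illustrative sketch (editor addition, not part of the original file): the key into `BARK_ATTENTION_CLASSES` comes
# from `config._attn_implementation`, which users normally set through the `attn_implementation` argument at load
# time. Assuming flash-attn 2 is installed, something like the following would route every `BarkBlock` to
# `BarkSelfFlashAttention2`; the kwargs follow the standard `from_pretrained` API.
#
#   from transformers import BarkModel
#
#   model = BarkModel.from_pretrained(
#       "suno/bark-small",
#       torch_dtype=torch.float16,
#       attn_implementation="flash_attention_2",
#   )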

class BarkLayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""

    def __init__(self, hidden_size, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, eps=1e-5)

class BarkMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.in_proj = nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=config.bias)
        self.out_proj = nn.Linear(4 * config.hidden_size, config.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.gelu = nn.GELU()

    def forward(self, hidden_states):
        hidden_states = self.in_proj(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states

class BarkBlock(nn.Module):
    def __init__(self, config, is_causal=False):
        super().__init__()

        if is_causal:
            # if causal, uses handmade LayerNorm, so that the layerNorm bias is optional
            # this handmade layerNorm is used to stick with Bark choice of leaving optional bias in
            # AutoRegressive models (corresponding to the "Text" and the "Coarse" modules)
            self.layernorm_1 = BarkLayerNorm(config.hidden_size, bias=config.bias)
            self.layernorm_2 = BarkLayerNorm(config.hidden_size, bias=config.bias)
        else:
            self.layernorm_1 = nn.LayerNorm(config.hidden_size)
            self.layernorm_2 = nn.LayerNorm(config.hidden_size)

        self.attn = BARK_ATTENTION_CLASSES[config._attn_implementation](config, is_causal=is_causal)

        self.mlp = BarkMLP(config)

    def forward(
        self,
        hidden_states,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        use_cache=False,
        output_attentions=False,
    ):
        intermediary_hidden_states = self.layernorm_1(hidden_states)

        attn_outputs = self.attn(
            intermediary_hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        attn_output = attn_outputs[0]  # output_attn: output, present_key_values, (attn_weights)
        outputs = attn_outputs[1:]

        intermediary_hidden_states = hidden_states + attn_output
        intermediary_hidden_states = intermediary_hidden_states + self.mlp(
            self.layernorm_2(intermediary_hidden_states)
        )

        if use_cache:
            outputs = (intermediary_hidden_states,) + outputs
        else:
            outputs = (intermediary_hidden_states,) + outputs[1:]

        return outputs  # hidden_states, ((present), attentions)

class BarkPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BarkConfig
    supports_gradient_checkpointing = False
    _supports_flash_attn_2 = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear,)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """

        # if has _hf_hook, has been offloaded so the device has to be found in the hook
        if not hasattr(self, "_hf_hook"):
            return get_parameter_device(self)
        for module in self.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)

        return get_parameter_device(self)

BARK_MODEL_START_DOCSTRING = """
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`{config}`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


BARK_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`BarkConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


BARK_FINE_INPUTS_DOCSTRING = r"""
    Args:
        codebook_idx (`int`):
            Index of the codebook that will be predicted.
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, number_of_codebooks)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it. Initially, indices of the first two codebooks are obtained from the `coarse` sub-model. The rest is
            predicted recursively by attending the previously predicted channels. The model predicts on windows of
            length 1024.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): NOT IMPLEMENTED YET.
        input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. If
            `past_key_values` is used, optionally only the last `input_embeds` have to be input (see
            `past_key_values`). This is useful if you want more control over how to convert `input_ids` indices into
            associated vectors than the model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

BARK_CAUSAL_MODEL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        input_embeds (`torch.FloatTensor` of shape `(batch_size, input_sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            Here, due to `Bark` particularities, if `past_key_values` is used, `input_embeds` will be ignored and you
            have to use `input_ids`. If `past_key_values` is not used and `use_cache` is set to `True`, `input_embeds`
            is used in priority instead of `input_ids`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# GPT2-like autoregressive model
class BarkCausalModel(BarkPreTrainedModel, GenerationMixin):
    config_class = BarkSubModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # initialize as an autoregressive GPT-like model
        self.input_embeds_layer = nn.Embedding(config.input_vocab_size, config.hidden_size)
        self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)

        self.drop = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"

        self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias)

        self.lm_head = nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.input_embeds_layer

    def set_input_embeddings(self, new_embeddings):
        self.input_embeds_layer = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # Overwritten -- bark has a model-specific hack
        input_embeds = kwargs.get("input_embeds", None)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if past_key_values is not None:
            # Omit tokens covered by past_key_values
            seq_len = input_ids.shape[1]
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

            # input_embeds have already been used and is not required anymore
            input_embeds = None
        else:
            if input_embeds is not None and kwargs.get("use_cache"):
                seq_len = input_embeds.shape[1]
            else:
                seq_len = input_ids.shape[1]

        # ensure that attention_mask and position_ids shapes are aligned with the weird Bark hack of reducing
        # sequence length on the first forward pass
        if attention_mask is not None:
            attention_mask = attention_mask[:, :seq_len]
        if position_ids is not None:
            position_ids = position_ids[:, :seq_len]

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        else:
            position_ids = None

        if input_embeds is not None and kwargs.get("use_cache"):
            return {
                "input_ids": None,
                "input_embeds": input_embeds,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "position_ids": position_ids,
                "attention_mask": attention_mask,
            }
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

    @add_start_docstrings_to_model_forward(BARK_CAUSAL_MODEL_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        loss = None
        if labels is not None:
            raise NotImplementedError(
                "Training is not implemented yet for Bark - ensure you do not pass `labels` to the model."
            )

        # Verify if input_embeds already exists
        # then compute embeddings.
        if input_ids is not None and input_embeds is not None:
            raise ValueError("You cannot specify both input_ids and input_embeds at the same time")
        elif input_embeds is not None and past_key_values is None:
            # we want to return the input_embeds in priority so that it is in line with a weird hack
            # of Bark which concatenate two bits of the input_embeds on the first forward pass of the semantic model
            pass
        elif input_ids is not None:
            input_embeds = self.input_embeds_layer(input_ids)  # token embeddings of shape (b, t, n_embd)
        elif input_embeds is not None:
            pass
        else:
            raise ValueError("You have to specify either input_ids or input_embeds")

        input_shape = input_embeds.size()[:-1]
        batch_size = input_embeds.shape[0]
        seq_length = input_shape[-1]

        device = input_ids.device if input_ids is not None else input_embeds.device

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.layers))
        else:
            past_length = past_key_values[0][0].size(-2)

        if position_ids is None:
            position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)  # shape (1, seq_length)

        position_embeds = self.position_embeds_layer(position_ids)  # position embeddings of shape (1, t, n_embd)

        # Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            if self._use_flash_attention_2:
                attention_mask = attention_mask if 0 in attention_mask else None
            else:
                attention_mask = attention_mask.view(batch_size, -1)
                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
                # from_seq_length is 1 to easily broadcast
                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x num_heads x N x N
        # head_mask has shape num_layers x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)

        hidden_states = self.drop(input_embeds + position_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        present_key_values = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, (block, past_layer_key_values) in enumerate(zip(self.layers, past_key_values)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = block(
                    hidden_states,
                    past_key_values=past_layer_key_values,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]

            if use_cache:
                present_key_values = present_key_values + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.layernorm_final(hidden_states)

        hidden_states = hidden_states.view(output_shape)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        logits = self.lm_head(hidden_states)

        if not return_dict:
            return tuple(
                v for v in [None, logits, present_key_values, all_hidden_states, all_self_attentions] if v is not None
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=present_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        # Necessary for beam_search
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )
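
# Illustrative sketch (editor addition, not part of the original file): the cache returned by
# `BarkCausalModel.forward` is a tuple with one `(key, value)` pair per layer, each tensor of shape
# (batch, num_heads, seq_len, head_dim). `_reorder_cache` simply gathers the batch dimension with `beam_idx`,
# e.g. for a hypothetical cache `past_key_values` and two beams swapped:
#
#   beam_idx = torch.tensor([1, 0])
#   reordered = BarkCausalModel._reorder_cache(past_key_values, beam_idx)
#   # reordered[layer][0] equals past_key_values[layer][0].index_select(0, beam_idx) for every layer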

@add_start_docstrings(
    """Bark semantic (or text) model. It shares the same architecture as the coarse model.
    It is a GPT-2 like autoregressive model with a language modeling head on top.""",
    BARK_MODEL_START_DOCSTRING.format(config="BarkSemanticConfig"),
)
class BarkSemanticModel(BarkCausalModel):
    base_model_prefix = "semantic"
    config_class = BarkSemanticConfig

    def generate(
        self,
        input_ids: torch.Tensor,
        semantic_generation_config: BarkSemanticGenerationConfig = None,
        history_prompt: Optional[Dict[str, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generates text semantic tokens from an input prompt and an additional optional `Bark` speaker prompt.

        Args:
            input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
                Input ids, i.e. tokenized input sentences. Will be truncated up to
                semantic_generation_config.max_input_semantic_length tokens. Note that the output audios will be as
                long as the longest generation among the batch.
            semantic_generation_config (`BarkSemanticGenerationConfig`):
                Generation config indicating how to generate the semantic tokens.
            history_prompt (`Optional[Dict[str,torch.Tensor]]`, *optional*):
                Optional `Bark` speaker prompt.
            attention_mask (`Optional[torch.Tensor]`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
        Returns:
            torch.LongTensor: Output semantic tokens.
        """
        if semantic_generation_config is None:
            raise ValueError("`semantic_generation_config` has to be provided")

        batch_size = input_ids.shape[0]
        max_input_semantic_length = semantic_generation_config.max_input_semantic_length

        input_ids = input_ids + semantic_generation_config.text_encoding_offset

        if attention_mask is not None:
            input_ids = input_ids.masked_fill((1 - attention_mask).bool(), semantic_generation_config.text_pad_token)

        if history_prompt is not None:
            semantic_history = history_prompt["semantic_prompt"][-max_input_semantic_length:]
            semantic_history = nn.functional.pad(
                semantic_history,
                (0, max_input_semantic_length - len(semantic_history)),
                value=semantic_generation_config.semantic_pad_token,
                mode="constant",
            )
        else:
            semantic_history = torch.tensor(
                [semantic_generation_config.semantic_pad_token] * max_input_semantic_length, dtype=torch.int
            ).to(self.device)

        semantic_history = torch.repeat_interleave(semantic_history[None], batch_size, dim=0)

        infer_array = torch.tensor(
            [[semantic_generation_config.semantic_infer_token]] * batch_size, dtype=torch.int
        ).to(self.device)

        input_embeds = torch.cat(
            [
                self.input_embeds_layer(input_ids[:, :max_input_semantic_length])
                + self.input_embeds_layer(semantic_history[:, : max_input_semantic_length + 1]),
                self.input_embeds_layer(infer_array),
            ],
            dim=1,
        )

        tokens_to_suppress = list(
            range(semantic_generation_config.semantic_vocab_size, semantic_generation_config.semantic_pad_token)
        )
        tokens_to_suppress.extend(
            list(range(semantic_generation_config.semantic_pad_token + 1, self.config.output_vocab_size))
        )

        suppress_tokens_logits_processor = SuppressTokensLogitsProcessor(tokens_to_suppress, device=input_ids.device)

        min_eos_p = kwargs.get("min_eos_p", semantic_generation_config.min_eos_p)
        early_stopping_logits_processor = BarkEosPrioritizerLogitsProcessor(
            eos_token_id=semantic_generation_config.eos_token_id, min_eos_p=min_eos_p, device=input_ids.device
        )

        # pass input_ids in order to stay consistent with the transformers generate method even though it is not used
        # (except to get the input seq_len - that's why we keep the first 257 tokens)
        semantic_output = super().generate(
            torch.ones((batch_size, max_input_semantic_length + 1), dtype=torch.int).to(self.device),
            input_embeds=input_embeds,
            logits_processor=[suppress_tokens_logits_processor, early_stopping_logits_processor],
            generation_config=semantic_generation_config,
            **kwargs,
        )  # size: 10048

        # take the generated semantic tokens
        semantic_output = semantic_output[:, max_input_semantic_length + 1 :]

        return semantic_output
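
# Illustrative sketch (editor addition, not part of the original file): `BarkSemanticModel.generate` is normally
# driven by `BarkModel.generate`, which builds the `BarkSemanticGenerationConfig` and the speaker prompt for you.
# A minimal end-to-end call through the standard `BarkProcessor`/`BarkModel` API and the "suno/bark-small"
# checkpoint would look roughly like this:
#
#   from transformers import AutoProcessor, BarkModel
#
#   processor = AutoProcessor.from_pretrained("suno/bark-small")
#   model = BarkModel.from_pretrained("suno/bark-small")
#   inputs = processor("Hello, my dog is cute", voice_preset="v2/en_speaker_6")
#   audio = model.generate(**inputs)  # runs the semantic, coarse and fine sub-models in sequence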

@add_start_docstrings(
    """Bark coarse acoustics model.
    It shares the same architecture as the semantic (or text) model. It is a GPT-2 like autoregressive model with a
    language modeling head on top.""",
    BARK_MODEL_START_DOCSTRING.format(config="BarkCoarseConfig"),
)
class BarkCoarseModel(BarkCausalModel):
    base_model_prefix = "coarse_acoustics"
    config_class = BarkCoarseConfig

    def preprocess_histories(
        self,
        max_coarse_history: int,
        semantic_to_coarse_ratio: int,
        batch_size: int,
        semantic_generation_config: int,
        codebook_size: int,
        history_prompt: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """
        Preprocess the optional `Bark` speaker prompts before `self.generate`.

        Args:
            max_coarse_history (`int`):
                Maximum size of coarse tokens used.
            semantic_to_coarse_ratio (`int`):
                Ratio of semantic to coarse frequency.
            batch_size (`int`):
                Batch size, i.e. the number of samples.
            semantic_generation_config (`BarkSemanticGenerationConfig`):
                Generation config indicating how to generate the semantic tokens.
            codebook_size (`int`):
                Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
            history_prompt (`Optional[Dict[str,torch.Tensor]]`):
                Optional `Bark` speaker prompt.
        Returns:
            `tuple(torch.FloatTensor)`:
            - **x_semantic_history** (`torch.FloatTensor`) -- Processed semantic speaker prompt.
            - **x_coarse_history** (`torch.FloatTensor`) -- Processed coarse speaker prompt.
        """
        if history_prompt is not None:
            x_semantic_history = torch.repeat_interleave(history_prompt["semantic_prompt"][None], batch_size, dim=0)
            # clone to avoid modifying history_prompt.coarse_prompt
            x_coarse_history = history_prompt["coarse_prompt"].clone()

            # offset x_coarse_history
            if codebook_size is not None:
                for n in range(1, x_coarse_history.shape[0]):
                    # offset
                    x_coarse_history[n, :] += codebook_size * n

            # flatten x_coarse_history
            x_coarse_history = torch.transpose(x_coarse_history, 0, 1).reshape(-1)

            x_coarse_history = x_coarse_history + semantic_generation_config.semantic_vocab_size

            x_coarse_history = torch.repeat_interleave(x_coarse_history[None], batch_size, dim=0)
            # e.g: after SEMANTIC_VOCAB_SIZE (10000), 1024 tokens dedicated to first codebook, 1024 next tokens
            # dedicated to second codebook.

            max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
            # trim histories correctly
            n_semantic_hist_provided = min(
                [
                    max_semantic_history,
                    x_semantic_history.shape[1] - x_semantic_history.shape[1] % 2,
                    int(np.floor(x_coarse_history.shape[1] / semantic_to_coarse_ratio)),
                ]
            )

            n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))

            x_semantic_history = x_semantic_history[:, -n_semantic_hist_provided:].int()
            x_coarse_history = x_coarse_history[:, -n_coarse_hist_provided:].int()
            # bit of a hack for time alignment (sounds better) - from Bark original implementation
            x_coarse_history = x_coarse_history[:, :-2]

        else:
            # shape: (batch_size, 0)
            x_semantic_history = torch.tensor([[]] * batch_size, dtype=torch.int).to(self.device)
            x_coarse_history = torch.tensor([[]] * batch_size, dtype=torch.int).to(self.device)

        return x_semantic_history, x_coarse_history
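
    # Illustrative sketch (editor addition, not part of the original file): the per-codebook offset applied above
    # keeps the coarse codebooks in disjoint id ranges placed after the semantic vocabulary. With the assumed
    # values codebook_size=1024 and semantic_vocab_size=10000, a tiny hypothetical coarse prompt behaves like:
    #
    #   coarse_prompt = torch.tensor([[3, 7, 2], [5, 1, 4]])         # (n_coarse_codebooks=2, seq_len=3)
    #   offset = coarse_prompt.clone()
    #   offset[1, :] += 1024                                         # second codebook shifted by codebook_size
    #   flat = offset.transpose(0, 1).reshape(-1) + 10000            # time-major flatten + semantic_vocab_size
    #   # flat == tensor([10003, 11029, 10007, 11025, 10002, 11028])
    #   # codebook-0 ids land in [10000, 11024), codebook-1 ids in [11024, 12048)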

    def generate(
        self,
        semantic_output: torch.Tensor,
        semantic_generation_config: BarkSemanticGenerationConfig = None,
        coarse_generation_config: BarkCoarseGenerationConfig = None,
        codebook_size: int = 1024,
        history_prompt: Optional[Dict[str, torch.Tensor]] = None,
        return_output_lengths: Optional[bool] = None,
        **kwargs,
    ) -> Union[torch.LongTensor, Tuple[torch.LongTensor, torch.LongTensor]]:
        """
        Generates coarse acoustics tokens from input text semantic tokens and an additional optional `Bark` speaker
        prompt.

        Args:
            semantic_output (`torch.Tensor` of shape (batch_size, seq_len), *optional*):
                Input text semantic ids, i.e. the output of `BarkSemanticModel.generate`.
            semantic_generation_config (`BarkSemanticGenerationConfig`):
                Generation config indicating how to generate the semantic tokens.
            coarse_generation_config (`BarkCoarseGenerationConfig`):
                Generation config indicating how to generate the coarse tokens.
            codebook_size (`int`, *optional*, defaults to 1024):
                Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
            history_prompt (`Optional[Dict[str,torch.Tensor]]`, *optional*):
                Optional `Bark` speaker prompt.
            return_output_lengths (`bool`, *optional*):
                Whether or not to return the output lengths. Useful when batching.
        Returns:
            By default:
                torch.LongTensor: Output coarse acoustics tokens.
            If `return_output_lengths=True`:
                `Tuple(torch.Tensor, torch.Tensor)`: The output coarse acoustics tokens, and the length of each
                sample of the batch.
        """
        if semantic_generation_config is None:
            raise ValueError("`semantic_generation_config` has to be provided")

        if coarse_generation_config is None:
            raise ValueError("`coarse_generation_config` has to be provided")

        max_coarse_input_length = coarse_generation_config.max_coarse_input_length
        max_coarse_history = coarse_generation_config.max_coarse_history
        sliding_window_len = coarse_generation_config.sliding_window_len

        # replace semantic_pad_token (eos_tok and pad_tok here) with coarse_semantic_pad_token i.e. the pad_token
        # used in the next model
        semantic_output.masked_fill_(
            semantic_output == semantic_generation_config.semantic_pad_token,
            coarse_generation_config.coarse_semantic_pad_token,
        )

        semantic_to_coarse_ratio = (
            coarse_generation_config.coarse_rate_hz
            / semantic_generation_config.semantic_rate_hz
            * coarse_generation_config.n_coarse_codebooks
        )
        max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))

        output_lengths = (semantic_output != coarse_generation_config.coarse_semantic_pad_token).sum(1)
        output_lengths = torch.floor(
            output_lengths * semantic_to_coarse_ratio / coarse_generation_config.n_coarse_codebooks
        )
        output_lengths = torch.round(output_lengths * coarse_generation_config.n_coarse_codebooks).int()

        max_generated_len = torch.max(output_lengths).item()

        batch_size = semantic_output.shape[0]

        x_semantic_history, x_coarse = self.preprocess_histories(
            history_prompt=history_prompt,
            max_coarse_history=max_coarse_history,
            semantic_to_coarse_ratio=semantic_to_coarse_ratio,
            batch_size=batch_size,
            semantic_generation_config=semantic_generation_config,
            codebook_size=codebook_size,
        )
        base_semantic_idx = x_semantic_history.shape[1]

        semantic_output = torch.hstack([x_semantic_history, semantic_output])

        n_window_steps = int(np.ceil(max_generated_len / sliding_window_len))

        total_generated_len = 0

        len_coarse_history = x_coarse.shape[1]

        for _ in range(n_window_steps):
            semantic_idx = base_semantic_idx + int(round(total_generated_len / semantic_to_coarse_ratio))

            # pad from right side
            input_coarse = semantic_output[:, np.max([0, semantic_idx - max_semantic_history]) :]
            input_coarse = input_coarse[:, :max_coarse_input_length]
            input_coarse = F.pad(
                input_coarse,
                (0, max_coarse_input_length - input_coarse.shape[-1]),
                "constant",
                coarse_generation_config.coarse_semantic_pad_token,
            )

            input_coarse = torch.hstack(
                [
                    input_coarse,
                    torch.tensor([[coarse_generation_config.coarse_infer_token]] * batch_size).to(self.device),
                    x_coarse[:, -max_coarse_history:],
                ]
            )

            alternatingLogitsProcessor = AlternatingCodebooksLogitsProcessor(
                input_coarse.shape[1],
                semantic_generation_config.semantic_vocab_size,
                codebook_size,
            )

            output_coarse = super().generate(
                input_coarse,
                logits_processor=[alternatingLogitsProcessor],
                max_new_tokens=min(sliding_window_len, max_generated_len - total_generated_len),
                generation_config=coarse_generation_config,
                **kwargs,
            )

            input_coarse_len = input_coarse.shape[1]

            x_coarse = torch.hstack([x_coarse, output_coarse[:, input_coarse_len:]])
            total_generated_len = x_coarse.shape[1] - len_coarse_history

            del output_coarse

        coarse_output = x_coarse[:, len_coarse_history:]

        if return_output_lengths:
            return coarse_output, output_lengths

        return coarse_output
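
# Illustrative sketch (editor addition, not part of the original file): the windowed loop above sizes its outputs
# with the semantic-to-coarse ratio. Assuming the default Bark generation-config values coarse_rate_hz=75,
# semantic_rate_hz=49.9, n_coarse_codebooks=2 and max_coarse_history=630 (values assumed here, not taken from
# this file), one semantic token maps to roughly three coarse tokens:
#
#   semantic_to_coarse_ratio = 75 / 49.9 * 2                                # ~= 3.006
#   max_semantic_history = int(np.floor(630 / semantic_to_coarse_ratio))    # ~= 209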
  929. @add_start_docstrings(
  930. """Bark fine acoustics model. It is a non-causal GPT-like model with `config.n_codes_total` embedding layers and
  931. language modeling heads, one for each codebook.""",
  932. BARK_MODEL_START_DOCSTRING.format(config="BarkFineConfig"),
  933. )
  934. class BarkFineModel(BarkPreTrainedModel):
  935. base_model_prefix = "fine_acoustics"
  936. config_class = BarkFineConfig
  937. main_input_name = "codebook_idx"
  938. def __init__(self, config):
  939. # non-causal gpt-like model with one embedding layer and one lm_head for each codebook of Encodec
  940. super().__init__(config)
  941. self.config = config
  942. # initialize a modified non causal GPT-like model
  943. # note that for there is one embedding layer and one lm_head for each codebook of Encodec
  944. self.input_embeds_layers = nn.ModuleList(
  945. [nn.Embedding(config.input_vocab_size, config.hidden_size) for _ in range(config.n_codes_total)]
  946. )
  947. self.position_embeds_layer = nn.Embedding(config.block_size, config.hidden_size)
  948. self.drop = nn.Dropout(config.dropout)
  949. self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)])
  950. self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
  951. self.layernorm_final = nn.LayerNorm(config.hidden_size)
  952. self.lm_heads = nn.ModuleList(
  953. [
  954. nn.Linear(config.hidden_size, config.output_vocab_size, bias=False)
  955. for _ in range(config.n_codes_given, config.n_codes_total)
  956. ]
  957. )
  958. self.gradient_checkpointing = False
  959. self.n_codes_total = config.n_codes_total
  960. # Initialize weights and apply final processing
  961. self.post_init()
    def get_input_embeddings(self):
        # one embedding layer for each codebook
        return self.input_embeds_layers

    def set_input_embeddings(self, new_embeddings):
        # one embedding layer for each codebook
        self.input_embeds_layers = new_embeddings

    def get_output_embeddings(self):
        # one lm_head for each codebook
        return self.lm_heads

    def set_output_embeddings(self, new_output_embeddings):
        # one lm_head for each codebook
        self.lm_heads = new_output_embeddings

    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
        old_embeddings_list = self.get_input_embeddings()
        new_embeddings_list = nn.ModuleList(
            [
                self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
                for old_embeddings in old_embeddings_list
            ]
        )
        self.set_input_embeddings(new_embeddings_list)
        new_num_tokens = new_embeddings_list[0].weight.shape[0]

        # if word embeddings are not tied, make sure that lm head is resized as well
        if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
            old_lm_head_list = self.get_output_embeddings()
            new_lm_head_list = nn.ModuleList(
                [self._get_resized_lm_head(old_lm_head, new_num_tokens) for old_lm_head in old_lm_head_list]
            )
            self.set_output_embeddings(new_lm_head_list)

        return self.get_input_embeddings()
    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
    ) -> nn.Embedding:
        """
        Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_tokens (`int`, *optional*):
                The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
                returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set will pad the embedding matrix to a multiple of the provided value.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                details about this, or help on choosing the correct value for resizing, refer to this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
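
        Example (an illustrative sketch; the extra token count is arbitrary):

        ```python
        >>> from transformers import BarkModel

        >>> model = BarkModel.from_pretrained("suno/bark-small")

        >>> # grow every per-codebook embedding table of the fine sub-model by 8 tokens;
        >>> # the returned object holds the resized embedding layers, one per codebook
        >>> embeddings_list = model.fine_acoustics.resize_token_embeddings(
        ...     model.fine_acoustics.config.input_vocab_size + 8
        ... )
        ```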
  1011. """
  1012. model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
  1013. if new_num_tokens is None and pad_to_multiple_of is None:
  1014. return model_embeds
  1015. # Update base model and current model config
  1016. self.config.output_vocab_size = model_embeds[0].weight.shape[0]
  1017. self.config.vocab_size = model_embeds[0].weight.shape[0]
  1018. self.output_vocab_size = model_embeds[0].weight.shape[0]
  1019. self.vocab_size = model_embeds[0].weight.shape[0]
  1020. # Tie weights again if needed
  1021. self.tie_weights()
  1022. return model_embeds
    def _tie_weights(self):
        if getattr(self.config, "tie_word_embeddings", True):
            self._tied_weights_keys = []
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()

            for i in range(self.config.n_codes_total - self.config.n_codes_given):
                # self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
                self._tie_or_clone_weights(output_embeddings[i], input_embeddings[i + 1])
                self._tied_weights_keys.append(f"lm_heads.{i}.weight")

    def tie_weights(self):
        """
        Tie the weights between the input embeddings list and the output embeddings list.

        If the `torchscript` flag is set in the configuration, parameter sharing can't be handled, so we clone the
        weights instead.
        """
        if getattr(self.config, "tie_word_embeddings", True):
            self._tied_weights_keys = []
            output_embeddings = self.get_output_embeddings()
            input_embeddings = self.get_input_embeddings()

            for i in range(self.config.n_codes_total - self.config.n_codes_given):
                # self.input_embeds_layers[i + 1].weight = self.lm_heads[i].weight
                self._tie_or_clone_weights(output_embeddings[i], input_embeddings[i + 1])
                self._tied_weights_keys.append(f"lm_heads.{i}.weight")

        for module in self.modules():
            if hasattr(module, "_tie_weights"):
                module._tie_weights()
    @add_start_docstrings_to_model_forward(BARK_FINE_INPUTS_DOCSTRING)
    def forward(
        self,
        codebook_idx: int,  # an additional idx corresponding to the id of the codebook that will be predicted
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        input_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        loss = None
        if labels is not None:
            raise NotImplementedError("Training is not implemented yet")

        if codebook_idx == 0:
            raise ValueError("Cannot predict 0th codebook - 0th codebook should be predicted by the coarse model")

        if input_ids is not None and input_embeds is not None:
            raise ValueError("You cannot specify both input_ids and input_embeds at the same time")

        if input_ids is None and input_embeds is None:
            raise ValueError("You have to specify either input_ids or input_embeds")

        if input_ids is not None:
            # the input_embeddings are the sum of the j previous codebooks embeddings before
            # the current codebook_idx codebook

            # forward the GPT model itself
            input_embeds = [
                input_embeds_layer(input_ids[:, :, i]).unsqueeze(-1)
                for i, input_embeds_layer in enumerate(self.input_embeds_layers)
            ]  # token embeddings of shape (b, t, n_embd)
            input_embeds = torch.cat(input_embeds, dim=-1)
            input_embeds = input_embeds[:, :, :, : codebook_idx + 1].sum(dim=-1)
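            # e.g. when predicting codebook_idx == 2, the embeddings of codebooks 0, 1 and 2
            # (every codebook up to and including the one being predicted) are summed into a single hidden state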
        input_shape = input_embeds.size()[:-1]
        batch_size = input_embeds.shape[0]
        seq_length = input_shape[1]

        device = input_ids.device if input_ids is not None else input_embeds.device

        if position_ids is None:
            position_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0)  # shape (1, seq_length)

        position_embeds = self.position_embeds_layer(position_ids)  # position embeddings of shape (1, t, n_embd)

        # Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            if self._use_flash_attention_2:
                attention_mask = attention_mask if 0 in attention_mask else None
            else:
                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
                # from_seq_length is 1 to easily broadcast
                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

        head_mask = self.get_head_mask(head_mask, self.config.num_layers)

        hidden_states = self.drop(input_embeds + position_embeds)
        output_shape = input_shape + (hidden_states.size(-1),)

        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for i, block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask[i],
                output_attentions=output_attentions,
            )

            hidden_states = outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)

        hidden_states = self.layernorm_final(hidden_states)
        hidden_states = hidden_states.view(output_shape)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        logits = self.lm_heads[codebook_idx - self.config.n_codes_given](hidden_states)

        if not return_dict:
            return tuple(v for v in [None, logits, all_hidden_states, all_self_attentions] if v is not None)

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
    def generate(
        self,
        coarse_output: torch.Tensor,
        semantic_generation_config: BarkSemanticGenerationConfig = None,
        coarse_generation_config: BarkCoarseGenerationConfig = None,
        fine_generation_config: BarkFineGenerationConfig = None,
        codebook_size: int = 1024,
        history_prompt: Optional[Dict[str, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generates fine acoustics tokens from input coarse acoustics tokens and an additional optional `Bark` speaker
        prompt.

        Args:
            coarse_output (`torch.Tensor` of shape (batch_size, seq_len)):
                Input coarse acoustics ids, i.e. the output of `BarkCoarseModel.generate`.
            semantic_generation_config (`BarkSemanticGenerationConfig`):
                Generation config indicating how to generate the semantic tokens.
            coarse_generation_config (`BarkCoarseGenerationConfig`):
                Generation config indicating how to generate the coarse tokens.
            fine_generation_config (`BarkFineGenerationConfig`):
                Generation config indicating how to generate the fine tokens.
            codebook_size (`int`, *optional*, defaults to 1024):
                Codebook channel size, i.e. the size of the output vocabulary per codebook channel.
            history_prompt (`Optional[Dict[str,torch.Tensor]]`, *optional*):
                Optional `Bark` speaker prompt.
        Returns:
            torch.LongTensor: Output fine acoustics tokens.
  1164. """
  1165. if semantic_generation_config is None:
  1166. raise ValueError("`semantic_generation_config` has to be provided")
  1167. if coarse_generation_config is None:
  1168. raise ValueError("`coarse_generation_config` has to be provided")
  1169. if fine_generation_config is None:
  1170. raise ValueError("`fine_generation_config` has to be provided")
  1171. # since we don't really use GenerationConfig through the fine model (autoencoder)
  1172. # and since only temperature is used from the classic GenerationConfig parameters
  1173. # manually impose the kwargs priority over the generation config
  1174. temperature = kwargs.get("temperature", fine_generation_config.temperature)
  1175. max_fine_history_length = fine_generation_config.max_fine_history_length
  1176. max_fine_input_length = fine_generation_config.max_fine_input_length
  1177. # shape: (batch, n_coarse_codebooks * seq_len)
  1178. # new_shape: (batch, seq_len, n_coarse_codebooks)
  1179. coarse_output = coarse_output.view(coarse_output.shape[0], -1, coarse_generation_config.n_coarse_codebooks)
  1180. # brings ids into the range [0, codebook_size -1]
  1181. coarse_output = torch.remainder(coarse_output - semantic_generation_config.semantic_vocab_size, codebook_size)
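        # e.g. a coarse id equal to `semantic_vocab_size + 5` maps back to 5; ids of the second coarse codebook,
        # which carry an additional `codebook_size` offset, also fold back into [0, codebook_size - 1]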
        batch_size = coarse_output.shape[0]

        if history_prompt is not None:
            x_fine_history = torch.repeat_interleave(history_prompt["fine_prompt"].T[None], batch_size, dim=0)
            # transpose to get to shape (seq_len, n_fine_codebooks)
        else:
            x_fine_history = None

        n_coarse = coarse_generation_config.n_coarse_codebooks

        # pad the remaining codebook channels (from n_coarse_codebooks up to n_fine_codebooks)
        fine_input = F.pad(
            coarse_output,
            (0, fine_generation_config.n_fine_codebooks - n_coarse),
            "constant",
            codebook_size,
        )

        # prepend history if available (at most max_fine_history_length steps)
        if x_fine_history is not None:
            fine_input = torch.cat([x_fine_history[:, -max_fine_history_length:, :], fine_input], dim=1)

            # len of the fine_history that has been added to fine_input
            n_history = x_fine_history[:, -max_fine_history_length:, :].shape[1]
        else:
            n_history = 0

        n_remove_from_end = 0
        # need to pad if too short (since non-causal model)
        if fine_input.shape[1] < max_fine_input_length:
            n_remove_from_end = max_fine_input_length - fine_input.shape[1]
            fine_input = F.pad(fine_input, (0, 0, 0, n_remove_from_end), mode="constant", value=codebook_size)

        # we can be lazy about the fractional loop and just keep overwriting codebooks.
        # note that coarse_output.shape[1] - (max_fine_input_length - n_history) equals -n_remove_from_end.
        # So if we needed to pad because the input was too short, n_loops is always 1 (because n_remove_from_end > 0).
        # If not, we loop at least twice.
        n_loops = (coarse_output.shape[1] - (max_fine_input_length - n_history)) / max_fine_history_length
        n_loops = int(np.ceil(n_loops))
        n_loops = max(0, n_loops) + 1
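        # illustrative numbers (assuming, for example, max_fine_input_length=1024 and max_fine_history_length=512):
        # with n_history=512 and a coarse output of 1500 steps, n_loops = ceil((1500 - (1024 - 512)) / 512) + 1 = 3
        # sliding windows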
        for n_outer in range(n_loops):
            start_idx = min([n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_input_length])

            start_fill_idx = min(
                [n_history + n_outer * max_fine_history_length, fine_input.shape[1] - max_fine_history_length]
            )
            rel_start_fill_idx = start_fill_idx - start_idx
            input_buffer = fine_input[:, start_idx : start_idx + max_fine_input_length, :]
            for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
                logits = self.forward(n_inner, input_buffer).logits
                if temperature is None or temperature == 1.0:
                    relevant_logits = logits[:, rel_start_fill_idx:, :codebook_size]
                    codebook_preds = torch.argmax(relevant_logits, -1)
                else:
                    relevant_logits = logits[:, :, :codebook_size] / temperature
                    # apply softmax
                    probs = F.softmax(relevant_logits, dim=-1)[:, rel_start_fill_idx:max_fine_input_length]
                    # reshape to 2D: (batch_size, seq_len, codebook_size) -> (batch_size*seq_len, codebook_size)
                    probs = probs.reshape((-1, codebook_size))
                    # multinomial then reshape: (batch_size*seq_len) -> (batch_size, seq_len)
                    codebook_preds = torch.multinomial(probs, num_samples=1).view(batch_size, -1)
                codebook_preds = codebook_preds.to(torch.int32)
                input_buffer[:, rel_start_fill_idx:, n_inner] = codebook_preds
                del logits, codebook_preds

            # transfer into fine_input
            for n_inner in range(n_coarse, fine_generation_config.n_fine_codebooks):
                fine_input[
                    :, start_fill_idx : start_fill_idx + (max_fine_input_length - rel_start_fill_idx), n_inner
                ] = input_buffer[:, rel_start_fill_idx:, n_inner]
            del input_buffer

        fine_input = fine_input.transpose(1, 2)[:, :, n_history:]
        if n_remove_from_end > 0:
            fine_input = fine_input[:, :, :-n_remove_from_end]

        if fine_input.shape[-1] != coarse_output.shape[-2]:
            raise ValueError("input and output should have the same seq_len")

        return fine_input

@add_start_docstrings(
    """
    The full Bark model, a text-to-speech model composed of 4 sub-models:
    - [`BarkSemanticModel`] (also referred to as the 'text' model): a causal auto-regressive transformer model that
    takes as input tokenized text, and predicts semantic text tokens that capture the meaning of the text.
    - [`BarkCoarseModel`] (also referred to as the 'coarse acoustics' model), also a causal autoregressive transformer,
    that takes as input the results of the last model. It aims at regressing the first two audio codebooks necessary
    for `encodec`.
    - [`BarkFineModel`] (the 'fine acoustics' model), this time a non-causal autoencoder transformer, which iteratively
    predicts the last codebooks based on the sum of the previous codebooks embeddings.
    - having predicted all the codebook channels from the [`EncodecModel`], Bark uses it to decode the output audio
    array.

    It should be noted that each of the first three modules can support conditional speaker embeddings to condition the
    output sound according to a specific predefined voice.
    """,
    BARK_START_DOCSTRING,
)
class BarkModel(BarkPreTrainedModel):
    config_class = BarkConfig

    def __init__(self, config):
        super().__init__(config)

        self.semantic = BarkSemanticModel(config.semantic_config)
        self.coarse_acoustics = BarkCoarseModel(config.coarse_acoustics_config)
        self.fine_acoustics = BarkFineModel(config.fine_acoustics_config)

        self.codec_model = AutoModel.from_config(config.codec_config)

        self.config = config

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        # for bark_model, device must be verified on its sub-models
        # if the model has an _hf_hook, it has been offloaded and the device has to be found in the hook
        if not hasattr(self.semantic, "_hf_hook"):
            return get_parameter_device(self)
        for module in self.semantic.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
    def enable_cpu_offload(self, gpu_id: Optional[int] = 0):
        r"""
        Offloads all sub-models to CPU using accelerate, reducing memory usage with a low impact on performance. This
        method moves one whole sub-model at a time to the GPU when it is used, and the sub-model remains on GPU until
        the next sub-model runs.

        Args:
            gpu_id (`int`, *optional*, defaults to 0):
                GPU id on which the sub-models will be loaded and offloaded.
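
        Example (an illustrative sketch; requires `accelerate` and a CUDA device):

        ```python
        >>> from transformers import BarkModel

        >>> model = BarkModel.from_pretrained("suno/bark-small")

        >>> # sub-models stay on CPU and are moved to `cuda:0` one at a time while `generate` runs
        >>> model.enable_cpu_offload()
        ```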
  1302. """
  1303. if is_accelerate_available():
  1304. from accelerate import cpu_offload_with_hook
  1305. else:
  1306. raise ImportError("`enable_model_cpu_offload` requires `accelerate`.")
  1307. device = torch.device(f"cuda:{gpu_id}")
  1308. if self.device.type != "cpu":
  1309. self.to("cpu")
  1310. torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
  1311. # this layer is used outside the first foward pass of semantic so need to be loaded before semantic
  1312. self.semantic.input_embeds_layer, _ = cpu_offload_with_hook(self.semantic.input_embeds_layer, device)
  1313. hook = None
  1314. for cpu_offloaded_model in [
  1315. self.semantic,
  1316. self.coarse_acoustics,
  1317. self.fine_acoustics,
  1318. ]:
  1319. _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
  1320. self.fine_acoustics_hook = hook
  1321. _, hook = cpu_offload_with_hook(self.codec_model, device, prev_module_hook=hook)
  1322. # We'll offload the last model manually.
  1323. self.codec_model_hook = hook
    def codec_decode(self, fine_output, output_lengths=None):
        """Turn quantized audio codes into audio array using encodec."""

        fine_output = fine_output.transpose(0, 1)
        emb = self.codec_model.quantizer.decode(fine_output)

        if output_lengths is not None:
            # encodec uses LSTMs which behave differently with appended padding
            # decoding with encodec takes around 0.1% of the total generation time
            # to keep generation quality, we break batching
            out = [sample[:, :l].unsqueeze(0) for (sample, l) in zip(emb, output_lengths)]
            audio_arr = [self.codec_model.decoder(sample).squeeze() for sample in out]
        else:
            out = self.codec_model.decoder(emb)
            audio_arr = out.squeeze(1)  # squeeze the codebook dimension

        return audio_arr
    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        history_prompt: Optional[Dict[str, torch.Tensor]] = None,
        return_output_lengths: Optional[bool] = None,
        **kwargs,
    ) -> torch.LongTensor:
        """
        Generates audio from an input prompt and an additional optional `Bark` speaker prompt.

        Args:
            input_ids (`Optional[torch.Tensor]` of shape (batch_size, seq_len), *optional*):
                Input ids. Will be truncated up to 256 tokens. Note that the output audios will be as long as the
                longest generation among the batch.
            history_prompt (`Optional[Dict[str,torch.Tensor]]`, *optional*):
                Optional `Bark` speaker prompt. Note that for now, this model takes only one speaker prompt per batch.
            kwargs (*optional*): Remaining dictionary of keyword arguments. Keyword arguments are of two types:

                - Without a prefix, they will be entered as `**kwargs` for the `generate` method of each sub-model.
                - With a *semantic_*, *coarse_*, *fine_* prefix, they will be passed to the `generate` method of the
                semantic, coarse and fine sub-models respectively. These take priority over the keywords without a
                prefix.

                This means you can, for example, specify a generation strategy for all sub-models except one.
            return_output_lengths (`bool`, *optional*):
                Whether or not to return the waveform lengths. Useful when batching.
        Returns:
            By default:
                - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
            When `return_output_lengths=True`:
                Returns a tuple made of:
                - **audio_waveform** (`torch.Tensor` of shape (batch_size, seq_len)): Generated audio waveform.
                - **output_lengths** (`torch.Tensor` of shape (batch_size)): The length of each waveform in the batch.
        Example:

        ```python
        >>> from transformers import AutoProcessor, BarkModel

        >>> processor = AutoProcessor.from_pretrained("suno/bark-small")
        >>> model = BarkModel.from_pretrained("suno/bark-small")

        >>> # To add a voice preset, you can pass `voice_preset` to `BarkProcessor.__call__(...)`
        >>> voice_preset = "v2/en_speaker_6"

        >>> inputs = processor("Hello, my dog is cute, I need him in my life", voice_preset=voice_preset)

        >>> audio_array = model.generate(**inputs, semantic_max_new_tokens=100)
        >>> audio_array = audio_array.cpu().numpy().squeeze()
        ```
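
        Sub-model specific keyword arguments are routed by their prefix. For instance (illustrative values only), a
        custom sampling temperature can be set for the coarse and fine sub-models:

        ```python
        >>> audio_array = model.generate(**inputs, coarse_temperature=0.8, fine_temperature=0.5)
        ```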
  1379. """
  1380. # TODO (joao):workaround until nested generation config is compatible with PreTrained Model
  1381. # todo: dict
  1382. semantic_generation_config = BarkSemanticGenerationConfig(**self.generation_config.semantic_config)
  1383. coarse_generation_config = BarkCoarseGenerationConfig(**self.generation_config.coarse_acoustics_config)
  1384. fine_generation_config = BarkFineGenerationConfig(**self.generation_config.fine_acoustics_config)
  1385. kwargs_semantic = {
  1386. # if "attention_mask" is set, it should not be passed to CoarseModel and FineModel
  1387. "attention_mask": kwargs.pop("attention_mask", None),
  1388. "min_eos_p": kwargs.pop("min_eos_p", None),
  1389. }
  1390. kwargs_coarse = {}
  1391. kwargs_fine = {}
  1392. for key, value in kwargs.items():
  1393. if key.startswith("semantic_"):
  1394. key = key[len("semantic_") :]
  1395. kwargs_semantic[key] = value
  1396. elif key.startswith("coarse_"):
  1397. key = key[len("coarse_") :]
  1398. kwargs_coarse[key] = value
  1399. elif key.startswith("fine_"):
  1400. key = key[len("fine_") :]
  1401. kwargs_fine[key] = value
  1402. else:
  1403. # If the key is already in a specific config, then it's been set with a
  1404. # submodules specific value and we don't override
  1405. if key not in kwargs_semantic:
  1406. kwargs_semantic[key] = value
  1407. if key not in kwargs_coarse:
  1408. kwargs_coarse[key] = value
  1409. if key not in kwargs_fine:
  1410. kwargs_fine[key] = value
        # 1. Generate from the semantic model
        if "generation_config" in kwargs_semantic:
            kwargs_semantic.pop("generation_config")
        semantic_output = self.semantic.generate(
            input_ids,
            history_prompt=history_prompt,
            semantic_generation_config=semantic_generation_config,
            **kwargs_semantic,
        )

        # 2. Generate from the coarse model
        if "generation_config" in kwargs_coarse:
            kwargs_coarse.pop("generation_config")
        coarse_output = self.coarse_acoustics.generate(
            semantic_output,
            history_prompt=history_prompt,
            semantic_generation_config=semantic_generation_config,
            coarse_generation_config=coarse_generation_config,
            codebook_size=self.generation_config.codebook_size,
            return_output_lengths=return_output_lengths,
            **kwargs_coarse,
        )

        output_lengths = None
        if return_output_lengths:
            coarse_output, output_lengths = coarse_output
            # (batch_size, seq_len*coarse_codebooks) -> (batch_size, seq_len)
            output_lengths = output_lengths // coarse_generation_config.n_coarse_codebooks

        # 3. "generate" from the fine model
        if "generation_config" in kwargs_fine:
            kwargs_fine.pop("generation_config")
        output = self.fine_acoustics.generate(
            coarse_output,
            history_prompt=history_prompt,
            semantic_generation_config=semantic_generation_config,
            coarse_generation_config=coarse_generation_config,
            fine_generation_config=fine_generation_config,
            codebook_size=self.generation_config.codebook_size,
            **kwargs_fine,
        )

        if getattr(self, "fine_acoustics_hook", None) is not None:
            # Manually offload fine_acoustics to CPU
            # and load codec_model to GPU
            # since bark doesn't use codec_model forward pass
            self.fine_acoustics_hook.offload()
            self.codec_model = self.codec_model.to(self.device)

        # 4. Decode the output and generate audio array
        audio = self.codec_decode(output, output_lengths)

        if getattr(self, "codec_model_hook", None) is not None:
            # Offload codec_model to CPU
            self.codec_model_hook.offload()

        if return_output_lengths:
            output_lengths = [len(sample) for sample in audio]
            audio = nn.utils.rnn.pad_sequence(audio, batch_first=True, padding_value=0)
            return audio, output_lengths

        return audio
    @classmethod
    def _check_and_enable_flash_attn_2(
        cls,
        config,
        torch_dtype: Optional[torch.dtype] = None,
        device_map: Optional[Union[str, Dict[str, int]]] = None,
        hard_check_only: bool = False,
        check_device_map: bool = False,
    ):
        """
        `_check_and_enable_flash_attn_2` originally doesn't expand flash attention enabling to the model
        sub-configurations. We override the original method to make sure that Bark sub-models are using Flash Attention
        if necessary.

        If you don't know about Flash Attention, check out the official repository of flash attention:
        https://github.com/Dao-AILab/flash-attention

        For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
        specific section of the documentation to learn more about it:
        https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models

        The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
        half precision and not run on CPU.

        If all checks pass and `hard_check_only` is False, the method will set the config attribute
        `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module.
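
        Example (an illustrative sketch; assumes a compatible CUDA GPU and the `flash-attn` package are available):

        ```python
        >>> import torch
        >>> from transformers import BarkModel

        >>> # half precision + Flash Attention 2, propagated to all Bark sub-models by this override
        >>> model = BarkModel.from_pretrained(
        ...     "suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
        ... ).to("cuda")
        ```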
  1487. """
  1488. config = super()._check_and_enable_flash_attn_2(
  1489. config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map
  1490. )
  1491. config.semantic_config._attn_implementation = config._attn_implementation
  1492. config.coarse_acoustics_config._attn_implementation = config._attn_implementation
  1493. config.fine_acoustics_config._attn_implementation = config._attn_implementation
  1494. return config