# coding=utf-8
# Copyright 2022 The OpenAI Team Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Jukebox model."""

import math
import os
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import LayerNorm as FusedLayerNorm

from ....activations import ACT2FN
from ....modeling_utils import PreTrainedModel
from ....utils import add_start_docstrings, logging
from ....utils.logging import tqdm
from .configuration_jukebox import ATTENTION_PATTERNS, JukeboxConfig, JukeboxPriorConfig, JukeboxVQVAEConfig


logger = logging.get_logger(__name__)

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """
    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits (`torch.Tensor`):
            logits distribution shape (vocabulary size)
        top_k (`int`, *optional*, defaults to 0):
            When `top_k > 0`, keep only the top k tokens with highest probability (top-k filtering).
        top_p (`float`, *optional*, defaults to 0.0):
            When `top_p > 0.0`, keep the top tokens with cumulative probability >= `top_p` (nucleus filtering).
    """
    logits = logits.clone()
    top_k = min(top_k, logits.size(-1))  # Safety check

    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1:]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # indices_to_remove = sorted_indices[sorted_indices_to_remove]
        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
    return logits

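# Minimal usage sketch for `filter_logits` (added for illustration, not part of the original code;
# the values below are arbitrary). `torch` and `F` are already imported at the top of this module.
#
#     logits = torch.randn(2, 2048)                         # (batch, vocabulary) scores
#     filtered = filter_logits(logits, top_k=5, top_p=0.9)   # keep top-5, then trim the nucleus tail
#     probs = F.softmax(filtered, dim=-1)                    # filtered entries get ~0 probability
#     next_token = torch.multinomial(probs, num_samples=1)
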
def get_relevant_lyric_tokens(full_tokens, max_n_lyric_tokens, total_length, offset, duration):
    """
    Extract only the relevant tokens based on the character position. A total of `max_n_lyric_tokens` tokens will be
    returned. If the provided token sequence is smaller, it will be padded; otherwise, only the tokens ranging from
    the midpoint - `max_n_lyric_tokens//2` to the midpoint + `max_n_lyric_tokens//2` will be returned. This *focuses*
    on the most relevant tokens (in time) for the sequence.

    Args:
        full_tokens (`List[int]`):
            List containing the token ids of the entire lyrics.
        max_n_lyric_tokens (`int`):
            Maximum number of lyric tokens to return.
        total_length (`int`):
            Total expected length of the music (not all of it is generated, see duration), in samples.
        offset (`int`):
            Starting sample in the music. If the offset is greater than 0, the lyrics will be shifted to take that
            into account.
        duration (`int`):
            Expected duration of the generated music, in samples. The duration has to be smaller than the total
            length, which represents the overall length of the signal.
    """
    full_tokens = full_tokens[0]
    if len(full_tokens) < max_n_lyric_tokens:
        tokens = torch.cat(
            [torch.zeros(max_n_lyric_tokens - len(full_tokens), dtype=torch.long).to(full_tokens.device), full_tokens]
        )
        indices = [-1] * (max_n_lyric_tokens - len(full_tokens)) + list(range(0, len(full_tokens)))
    else:
        midpoint = int(len(full_tokens) * (offset + duration / 2.0) / total_length)
        midpoint = min(max(midpoint, max_n_lyric_tokens // 2), len(full_tokens) - max_n_lyric_tokens // 2)
        tokens = full_tokens[midpoint - max_n_lyric_tokens // 2 : midpoint + max_n_lyric_tokens // 2]
        indices = list(range(midpoint - max_n_lyric_tokens // 2, midpoint + max_n_lyric_tokens // 2))
    return tokens.unsqueeze(dim=0), indices

# Break total_length into hops/windows of size n_ctx separated by hop_length
def get_starts(total_length, n_ctx, hop_length):
    starts = []
    for start in range(0, total_length - n_ctx + hop_length, hop_length):
        if start + n_ctx >= total_length:
            # Last hop could be smaller, we make it n_ctx to maximise context
            start = total_length - n_ctx
        starts.append(start)
    return starts

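# Worked example for `get_starts` (illustrative values, not part of the original file): with
# total_length=20, n_ctx=8 and hop_length=5, the raw starts are [0, 5, 10, 15]; the last one would
# overflow the sequence, so it is clamped to total_length - n_ctx, giving [0, 5, 10, 12]. Every
# window therefore covers exactly n_ctx tokens and the final window is flush with the end.
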
def get_alignment(music_tokens, labels, prior, config):
    level = prior.levels - 1  # Top level used
    n_ctx = prior.n_ctx
    tokens = music_tokens[level]
    batch_size, total_length = tokens.shape[0], tokens.shape[1]
    if total_length < n_ctx:
        padding_length = n_ctx - total_length
        tokens = torch.cat(
            [tokens, torch.zeros(batch_size, n_ctx - total_length, dtype=tokens.dtype, device=tokens.device)], dim=1
        )
        total_length = tokens.shape[1]
    else:
        padding_length = 0

    hop_length = int(config.hop_fraction[-level - 1] * prior.n_ctx)
    alignment_head, alignment_layer = config.prior_alignment_head[0], config.prior_alignment_layer[0]
    attn_layers = {alignment_layer}
    alignment_hops = {}
    indices_hops = {}
    for start in tqdm(get_starts(total_length, n_ctx, hop_length), desc="Computing lyric to music alignment "):
        end = start + n_ctx
        # set metadata offset, sample_length and lyrics tokens
        metadata, indices_hop = prior.get_metadata(labels, start, config.sample_length, get_indices=True, offset=0)
        tokens_bs = torch.chunk(tokens, batch_size, dim=0)
        metadata_bs = torch.chunk(metadata, batch_size, dim=0)
        w_hops = []
        for tokens_i, metadata_i in zip(tokens_bs, metadata_bs):
            w_hop = prior.forward_tokens(tokens_i[:, start:end], [], metadata_i, get_attn_weights=attn_layers)
            w_hops.append(w_hop[0][:, alignment_head])
            del w_hop
        weights = torch.cat(w_hops, dim=0)
        del w_hops
        alignment_hop = weights.float().cpu().numpy()
        del weights

        # alignment_hop has shape (bs, n_ctx, nb_relevant_lyric_tokens)
        # indices_hop is a list of len=bs, each entry of len hps.nb_relevant_lyric_tokens
        indices_hops[start] = indices_hop
        alignment_hops[start] = alignment_hop

    # Combine attn for each hop into attn for full range
    # Use indices to place them into correct place for corresponding source tokens
    alignments = []
    for item in range(batch_size):
        # Note each item has different length lyrics
        full_tokens = labels[0, 3:]
        alignment = np.zeros((total_length, len(full_tokens) + 1))
        for start in reversed(get_starts(total_length, n_ctx, hop_length)):
            end = start + n_ctx
            alignment_hop = alignment_hops[start][item]
            indices = indices_hops[start][item]
            alignment[start:end, indices] = alignment_hop
        alignment = alignment[: total_length - padding_length, :-1]  # remove token padding, and last lyric index
        alignments.append(alignment)
    return alignments

def save_temp_audio(fname, lvl, metas, aud):
    aud = torch.clamp(aud, -1, 1).cpu().numpy()
    for i in list(range(aud.shape[0])):
        if metas is not None:
            artists, genres, lyrics = list(metas)[i].values()
            path = f"{fname}/lvl_{lvl}-{artists}-{genres}-{lyrics[:5]}-{i}"
            np.save(path, aud[i])
        else:
            np.save(f"{fname}/lvl_{lvl}-sample-{i}", aud[i])

def get_mask(mask, query_length, key_value_length, blocks, spread, device, sample, sample_t):
    # returns a mask of shape 1 x 1 x query_length x key_value_length or None if masking is not needed.
    if mask is None or query_length == 1:
        return None
    offset = sample_t - query_length if sample else max(key_value_length - query_length, 0)
    if mask == "autoregressive":
        # Masked dense
        mask = torch.ones(query_length, key_value_length, device=device).tril(offset)
    elif mask == "summary":
        # Masked summary
        mask = torch.ones(query_length, query_length, device=device).tril()
        mask = mask.view(query_length, blocks, query_length // blocks)[:, :-1, -key_value_length // blocks :]
        mask = (
            torch.nn.functional.pad(
                mask,
                (0, 0, 1, 0),
                value=1,
            )
            .contiguous()
            .view(query_length, key_value_length)
        )
    elif mask == "prime":
        mask = torch.ones(query_length, key_value_length, device=device).tril(offset)
    return mask.view(1, 1, query_length, key_value_length)

class JukeboxConv1D(nn.Module):
    def __init__(self, input_width, output_width):
        super().__init__()
        self.input_width = input_width
        self.output_width = output_width
        weight = torch.empty(input_width, output_width)
        bias = torch.zeros(output_width)
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(bias)

    def forward(self, hidden_states):
        size_out = (*hidden_states.size()[:-1], self.output_width)
        hidden_states = torch.addmm(
            self.bias.type_as(hidden_states),
            hidden_states.view(-1, hidden_states.size(-1)),
            self.weight.type_as(hidden_states),
        )
        hidden_states = hidden_states.view(*size_out)
        return hidden_states

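# Note (added for clarity, not in the original file): despite its name, JukeboxConv1D is a dense
# projection over the last dimension. `torch.addmm(bias, x_2d, weight)` computes
# `x_2d @ weight + bias` on the flattened input, i.e. the same result as
# `nn.Linear(input_width, output_width)` with the GPT-2-style transposed weight layout.
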
class JukeboxResConv1DBlock(nn.Module):
    def __init__(self, config, conv_width, depth=1, res_scale=1.0):
        super().__init__()
        hidden_dim = config.res_convolution_multiplier * conv_width
        dilation = config.res_dilation_growth_rate**depth
        padding = dilation

        self.res_scale = res_scale
        self.activation = nn.ReLU()
        self.conv1d_1 = nn.Conv1d(conv_width, hidden_dim, 3, 1, padding, dilation)
        self.conv1d_2 = nn.Conv1d(hidden_dim, conv_width, 1, 1, 0)

    def forward(self, hidden_states):
        residuals = hidden_states
        hidden_states = self.activation(hidden_states)
        hidden_states = self.conv1d_1(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.conv1d_2(hidden_states)
        return residuals + self.res_scale * hidden_states


class JukeboxResnet1D(nn.Module):
    def __init__(self, config, conv_width, n_depth, reverse_dilation=False):
        super().__init__()
        self.dilation_cycle = config.res_dilation_cycle
        res_scale = 1.0 if not config.conv_res_scale else 1.0 / math.sqrt(n_depth)

        blocks = []
        for depth in range(n_depth):
            block_depth = depth if self.dilation_cycle is None else depth % self.dilation_cycle
            blocks.append(JukeboxResConv1DBlock(config, conv_width, block_depth, res_scale))

        if reverse_dilation:
            blocks = blocks[::-1]
        self.resnet_block = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        for block in self.resnet_block:
            hidden_states = block(hidden_states)
        return hidden_states


class JukeboxEncoderConvBlock(nn.Module):
    def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t):
        super().__init__()
        blocks = []
        filter_t = stride_t * 2
        pad_t = stride_t // 2
        if down_t > 0:
            for i in range(down_t):
                blocks.append(nn.Conv1d(embed_dim if i == 0 else hidden_dim, hidden_dim, filter_t, stride_t, pad_t))
                blocks.append(JukeboxResnet1D(config, hidden_dim, depth))
        self.proj_out = nn.Conv1d(hidden_dim, config.embed_dim, 3, 1, 1)
        self.downsample_block = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        for block in self.downsample_block:
            hidden_states = block(hidden_states)
        hidden_states = self.proj_out(hidden_states)
        return hidden_states


class JukeboxEncoder(nn.Module):
    def __init__(self, config, width, depth, levels, downs_t, strides_t):
        super().__init__()
        self.levels = levels
        self.level_blocks = nn.ModuleList()

        iterator = zip(list(range(self.levels)), downs_t, strides_t)
        for i, down_t, stride_t in iterator:
            self.level_blocks.append(
                JukeboxEncoderConvBlock(
                    config, config.conv_input_shape if i == 0 else config.embed_dim, width, depth, down_t, stride_t
                )
            )

    def forward(self, hidden_states):
        all_hidden_states = []

        # 64, 32, ...
        for level in range(self.levels):
            level_block = self.level_blocks[level]
            hidden_states = level_block(hidden_states)
            all_hidden_states.append(hidden_states)

        return all_hidden_states

class JukeboxDecoderConvBock(nn.Module):
    def __init__(self, config, embed_dim, hidden_dim, depth, down_t, stride_t, reverse_dilation=True):
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        super().__init__()
        blocks = []
        if down_t > 0:
            filter_t = stride_t * 2
            pad_t = stride_t // 2
            self.proj_in = nn.Conv1d(embed_dim, hidden_dim, 3, 1, 1)
            for i in range(down_t):
                blocks.append(JukeboxResnet1D(config, hidden_dim, depth, reverse_dilation))
                blocks.append(
                    nn.ConvTranspose1d(
                        hidden_dim, hidden_dim if i < down_t - 1 else embed_dim, filter_t, stride_t, pad_t
                    )
                )
        self.upsample_block = nn.ModuleList(blocks)

    def forward(self, hidden_states):
        hidden_states = self.proj_in(hidden_states)
        for block in self.upsample_block:
            hidden_states = block(hidden_states)
        return hidden_states


class JukeboxDecoder(nn.Module):
    def __init__(self, config, hidden_dim, depth, levels, downs_t, strides_t):
        super().__init__()
        self.levels = levels
        self.level_blocks = nn.ModuleList()
        for level, down_t, stride_t in zip(list(range(self.levels)), downs_t, strides_t):
            self.level_blocks.append(
                JukeboxDecoderConvBock(config, config.embed_dim, hidden_dim, depth, down_t, stride_t)
            )

        self.out = nn.Conv1d(config.embed_dim, config.conv_input_shape, 3, 1, 1)

    def forward(self, hidden_states, all_levels=True):
        hidden_state = hidden_states[-1]

        # 32, 64 ...
        for level in reversed(range(self.levels)):
            level_block = self.level_blocks[level]
            hidden_state = level_block(hidden_state)

            if level != 0 and all_levels:
                hidden_state = hidden_state + hidden_states[level - 1]

        hidden_state = self.out(hidden_state)
        return hidden_state

class JukeboxBottleneckBlock(nn.Module):
    def __init__(self, config: JukeboxVQVAEConfig):
        super().__init__()
        self.nb_discrete_codes = config.nb_discrete_codes
        self.codebook_width = config.embed_dim
        self.mu = config.lmu
        self.threshold = 1.0
        self.init = False
        self.codebook_sum = None
        self.codebook_elem = None
        self.register_buffer("codebook", torch.zeros(self.nb_discrete_codes, self.codebook_width))

    def _tile(self, hidden_states):
        dim, embed_width = hidden_states.shape
        if dim < self.nb_discrete_codes:
            n_repeats = (self.nb_discrete_codes + dim - 1) // dim
            std = 0.01 / np.sqrt(embed_width)
            hidden_states = hidden_states.repeat(n_repeats, 1)
            hidden_states = hidden_states + torch.randn_like(hidden_states) * std
        return hidden_states

    def init_codebook(self, hidden_states):
        nb_discrete_codes = self.nb_discrete_codes
        self.init = True
        codes = self._tile(hidden_states)
        self.codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes]
        self.codebook_sum = self.codebook
        self.codebook_elem = torch.ones(nb_discrete_codes, device=self.codebook.device)

    def update_codebook(self, hidden_states, latent_states):
        mu, codebook_width, nb_discrete_codes = self.mu, self.codebook_width, self.nb_discrete_codes
        with torch.no_grad():
            # Calculate new centres
            # nb_discrete_codes, batch_size * seq_length
            latent_states_onehot = torch.zeros(nb_discrete_codes, hidden_states.shape[0], device=hidden_states.device)
            latent_states_onehot.scatter_(0, latent_states.view(1, hidden_states.shape[0]), 1)

            _codebook_sum = torch.matmul(latent_states_onehot, hidden_states)
            _codebook_elem = latent_states_onehot.sum(dim=-1)  # nb_discrete_codes
            codes = self._tile(hidden_states)
            _random_codebook = codes[torch.randperm(codes.shape[0])][:nb_discrete_codes]

            # Update centres
            old_codebook = self.codebook
            self.codebook_sum = mu * self.codebook_sum + (1.0 - mu) * _codebook_sum
            self.codebook_elem = mu * self.codebook_elem + (1.0 - mu) * _codebook_elem  # nb_discrete_codes
            usage = (self.codebook_elem.view(nb_discrete_codes, 1) >= self.threshold).float()

            norm_code = self.codebook_sum.view(nb_discrete_codes, codebook_width) / self.codebook_elem.view(
                nb_discrete_codes, 1
            )
            self.codebook = usage * (norm_code) + (1 - usage) * _random_codebook
            _codebook_prob = _codebook_elem / torch.sum(_codebook_elem)  # prob of each bin
            entropy = -torch.sum(_codebook_prob * torch.log(_codebook_prob + 1e-8))  # entropy ie how diverse
            used_curr = (_codebook_elem >= self.threshold).sum()
            usage = torch.sum(usage)
            dk = torch.norm(self.codebook - old_codebook) / np.sqrt(np.prod(old_codebook.shape))
        return {"entropy": entropy, "used_curr": used_curr, "usage": usage, "dk": dk}

    def preprocess(self, hidden_states):
        hidden_states = hidden_states.permute(0, 2, 1).contiguous()
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        if hidden_states.shape[-1] == self.codebook_width:
            prenorm = torch.norm(hidden_states - torch.mean(hidden_states)) / np.sqrt(np.prod(hidden_states.shape))
        elif hidden_states.shape[-1] == 2 * self.codebook_width:
            x1, x2 = hidden_states[..., : self.codebook_width], hidden_states[..., self.codebook_width :]
            prenorm = (torch.norm(x1 - torch.mean(x1)) / np.sqrt(np.prod(x1.shape))) + (
                torch.norm(x2 - torch.mean(x2)) / np.sqrt(np.prod(x2.shape))
            )
            # Normalise
            hidden_states = x1 + x2
        return hidden_states, prenorm

    def postprocess(self, latent_states, dequantised_states, x_shape):
        batch_size, time = x_shape
        dequantised_states = dequantised_states.view(batch_size, time, -1).permute(0, 2, 1).contiguous()
        latent_states = latent_states.view(batch_size, time)
        return latent_states, dequantised_states

    def quantise(self, latent_states):
        # Calculate latent code latent_states
        codebook_weights = self.codebook.t()
        distance = (
            torch.sum(latent_states**2, dim=-1, keepdim=True)
            - 2 * torch.matmul(latent_states, codebook_weights)
            + torch.sum(codebook_weights**2, dim=0, keepdim=True)
        )  # (batch_size * latent_states , codebook_weights)
        min_distance, music_tokens = torch.min(distance, dim=-1)
        fit = torch.mean(min_distance)
        return music_tokens, fit

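    # Note (added for clarity, not in the original file): `quantise` relies on the expansion
    # ||x - e||^2 = ||x||^2 - 2 * x.e + ||e||^2, evaluated for every (latent state, codebook entry)
    # pair at once, so `torch.min(distance, dim=-1)` yields the index of the nearest codebook vector
    # for each latent state, and the mean of those minimal distances is returned as `fit`.
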
    def dequantise(self, music_tokens):
        dequantised_states = F.embedding(music_tokens, self.codebook)
        return dequantised_states

    def encode(self, latent_states):
        samples, _, seq_len = latent_states.shape

        # Preprocess.
        latent_states, _ = self.preprocess(latent_states)

        # Quantise
        music_tokens, _ = self.quantise(latent_states)

        # Postprocess.
        music_tokens = music_tokens.view(samples, seq_len)
        return music_tokens

    def decode(self, music_tokens):
        samples, seq_len = music_tokens.shape

        # Dequantise
        dequantised_states = self.dequantise(music_tokens)

        # Postprocess
        dequantised_states = (
            dequantised_states.view(samples, seq_len, self.codebook_width).permute(0, 2, 1).contiguous()
        )
        return dequantised_states

    def forward(self, hidden_states, update_codebook=True):
        samples, _, seq_len = hidden_states.shape

        # Preprocess
        hidden_states, prenorm = self.preprocess(hidden_states)

        # Init codebook if not inited
        if update_codebook and not self.init:
            self.init_codebook(hidden_states)

        # Quantise and dequantise through bottleneck
        music_tokens, fit = self.quantise(hidden_states)
        dequantised_states = self.dequantise(music_tokens)

        # Update embeddings
        if update_codebook:
            update_metrics = self.update_codebook(hidden_states, music_tokens)
        else:
            update_metrics = {}

        # Loss
        commit_loss = torch.norm(dequantised_states.detach() - hidden_states) ** 2 / np.prod(hidden_states.shape)

        # Passthrough
        dequantised_states = hidden_states + (dequantised_states - hidden_states).detach()

        # Postprocess
        music_tokens, dequantised_states = self.postprocess(music_tokens, dequantised_states, (samples, seq_len))
        return music_tokens, dequantised_states, commit_loss, dict(fit=fit, pn=prenorm, **update_metrics)

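# Note (added for clarity, not in the original file): the "Passthrough" line in
# JukeboxBottleneckBlock.forward, `hidden_states + (dequantised_states - hidden_states).detach()`,
# is the straight-through estimator: the forward pass uses the quantised codebook vectors while
# gradients flow back to the encoder as if quantisation were the identity, and the commit loss is
# what pulls the encoder outputs toward the codebook vectors.
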
class JukeboxBottleneck(nn.Module):
    def __init__(self, config, levels):
        super().__init__()
        self.levels = levels
        self.level_blocks = nn.ModuleList()
        for level in range(self.levels):
            self.level_blocks.append(JukeboxBottleneckBlock(config))

    def encode(self, raw_audio):
        music_tokens = [
            level_block.encode(hidden_states) for (level_block, hidden_states) in zip(self.level_blocks, raw_audio)
        ]
        return music_tokens

    def decode(self, music_tokens, start_level=0, end_level=None):
        if end_level is None:
            end_level = self.levels
        quantised_audio = [
            level_block.decode(z) for (level_block, z) in zip(self.level_blocks[start_level:end_level], music_tokens)
        ]
        return quantised_audio

    def forward(self, input_audio):
        music_tokens, quantised_states, commit_losses, metrics = [], [], [], []
        for level in range(self.levels):
            level_block = self.level_blocks[-level - 1]
            hidden_states = input_audio[level]
            sampled_tokens, quantised_state, commit_loss, metric = level_block(
                hidden_states, update_codebook=self.training
            )
            music_tokens.append(sampled_tokens)
            if not self.training:
                # Be extra paranoid and make sure the encoder weights can't
                # change from straight-through estimator
                quantised_state = quantised_state.detach()
            quantised_states.append(quantised_state)
            commit_losses.append(commit_loss)
            if self.training:
                metrics.append(metric)
        return music_tokens, quantised_states, commit_losses, metrics

JUKEBOX_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general
    usage and behavior.

    Parameters:
        config (`JukeboxConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    """The Hierarchical VQ-VAE model used in Jukebox. This model follows the Hierarchical VQVAE paper from [Will Williams,
    Sam Ringer, Tom Ash, John Hughes, David MacLeod, Jamie Dougherty](https://arxiv.org/abs/2002.08111).
    """,
    JUKEBOX_START_DOCSTRING,
)
class JukeboxVQVAE(PreTrainedModel):
    config_class = JukeboxVQVAEConfig
    base_model_prefix = "vqvae"

    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):  # embed_tokens
            module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale)
        elif isinstance(module, JukeboxConv1D):
            if self.config.zero_out:
                module.weight.data.zero_()
            else:
                module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_scale)
        elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out:
            module.conv1d_2.weight.data.zero_()
            module.conv1d_2.bias.data.zero_()
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def __init__(self, config: JukeboxVQVAEConfig):
        super().__init__(config)
        downs_t = config.res_downs_t
        strides_t = config.res_strides_t
        if not config.sample_length:
            downsamples = [stride**down for stride, down in zip(strides_t, downs_t)]
            top_raw_to_tokens = np.prod(downsamples)
            config.sample_length = (
                config.sample_length_in_seconds * config.sampling_rate // top_raw_to_tokens
            ) * top_raw_to_tokens
            config.sample_length = config.sample_length.astype(int)

        self.nb_discrete_codes = config.nb_discrete_codes
        self.commit = config.commit
        self.sample_length = config.sample_length

        self.downsamples = [stride**down for stride, down in zip(strides_t, downs_t)]
        self.hop_lengths = np.cumprod(self.downsamples)
        self.levels = levels = config.levels
        self.music_tokens_shapes = [
            (int(self.sample_length // self.hop_lengths[-level - 1])) for level in range(levels)
        ]

        self.multipliers = config.multipliers if config.multipliers is not None else [1] * levels

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for level in range(levels):
            width = config.res_conv_width * self.multipliers[level]
            depth = config.res_conv_depth * self.multipliers[level]
            self.encoders.append(
                JukeboxEncoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1])
            )
            self.decoders.append(
                JukeboxDecoder(config, width, depth, level + 1, downs_t[: level + 1], strides_t[: level + 1])
            )

        self.bottleneck = JukeboxBottleneck(config, levels)

    def _decode(self, music_tokens, start_level=0, end_level=None):
        # Decode
        if end_level is None:
            end_level = self.levels
        latent_states = self.bottleneck.decode(music_tokens, start_level=start_level, end_level=end_level)
        # Use only lowest level
        decoder, dequantised_state = self.decoders[start_level], latent_states[0:1]
        dequantised_state = decoder(dequantised_state, all_levels=False)
        dequantised_state = dequantised_state.permute(0, 2, 1)
        return dequantised_state

    def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1) -> torch.Tensor:
        """
        Transforms the input `music_tokens` to their `raw_audio` representation.

        Args:
            music_tokens (`torch.LongTensor`):
                Tensor of music tokens which will be decoded to raw audio by using the codebook. Each music token
                should be an index to a corresponding `code` vector in the codebook.
            start_level (`int`, *optional*, defaults to 0):
                Level at which the decoding process will start.
            end_level (`int`, *optional*):
                Level at which the decoding process will end. Defaults to None.
            bs_chunks (`int`, *optional*, defaults to 1):
                Number of chunks to process at the same time.
        """
        token_chunks = [torch.chunk(token, bs_chunks, dim=0) for token in music_tokens]
        dequantised_states = []
        for i in range(bs_chunks):
            music_tokens_i = [chunks[i] for chunks in token_chunks]
            dequantised_state = self._decode(music_tokens_i, start_level=start_level, end_level=end_level)
            dequantised_states.append(dequantised_state)
        return torch.cat(dequantised_states, dim=0)

    def _encode(self, raw_audio, start_level=0, end_level=None):
        # Encode
        if end_level is None:
            end_level = self.levels
        input_audio = raw_audio.permute(0, 2, 1).float()
        latent_states = []
        for level in range(self.levels):
            encoder = self.encoders[level]
            latent_state = encoder(input_audio)
            latent_states.append(latent_state[-1])
        music_tokens = self.bottleneck.encode(latent_states)
        return music_tokens[start_level:end_level]

    def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1):
        """
        Transforms the `input_audio` to a discrete representation made out of `music_tokens`.

        Args:
            input_audio (`torch.Tensor`):
                Raw audio which will be encoded to its discrete representation using the codebook. The closest `code`
                from the codebook will be computed for each sequence of samples.
            start_level (`int`, *optional*, defaults to 0):
                Level at which the encoding process will start.
            end_level (`int`, *optional*):
                Level at which the encoding process will end. Defaults to None.
            bs_chunks (`int`, *optional*, defaults to 1):
                Number of chunks of raw audio to process at the same time.
        """
        audio_chunks = torch.chunk(input_audio, bs_chunks, dim=0)
        music_tokens_list = []
        for chunk_i in audio_chunks:
            music_tokens_i = self._encode(chunk_i, start_level=start_level, end_level=end_level)
            music_tokens_list.append(music_tokens_i)
        music_tokens = [torch.cat(music_tokens_level, dim=0) for music_tokens_level in zip(*music_tokens_list)]
        return music_tokens

    def sample(self, n_samples):
        music_tokens = [
            torch.randint(0, self.nb_discrete_codes, size=(n_samples, *music_tokens_shape), device="cpu")
            for music_tokens_shape in self.music_tokens_shapes
        ]
        return self.decode(music_tokens)

    def forward(self, raw_audio: torch.FloatTensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the VQ-VAE, encodes the `raw_audio` to latent states, which are then decoded for each level.
        The commit loss, which ensures that the encoder's computed embeddings are close to the codebook vectors, is
        computed.

        Args:
            raw_audio (`torch.FloatTensor`):
                Audio input which will be encoded and decoded.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`

        Example:
        ```python
        >>> from transformers import JukeboxVQVAE, set_seed
        >>> import torch

        >>> model = JukeboxVQVAE.from_pretrained("openai/jukebox-1b-lyrics").eval()
        >>> set_seed(0)
        >>> zs = [torch.randint(100, (4, 1))]
        >>> model.decode(zs).shape
        torch.Size([4, 8, 1])
        ```
        """
        # Encode/Decode
        input_audio = raw_audio.permute(0, 2, 1).float()
        latent_states = []
        for level in range(self.levels):
            encoder = self.encoders[level]
            latent_state = encoder(input_audio)
            latent_states.append(latent_state[-1])

        _, music_tokens, commit_losses, _ = self.bottleneck(latent_states)
        dequantised_states = []
        for level in range(self.levels):
            decoder = self.decoders[level]
            dequantised_state = decoder(music_tokens[level : level + 1], all_levels=False)
            dequantised_states.append(dequantised_state.permute(0, 2, 1))

        commit_loss = sum(commit_losses)
        loss = self.commit * commit_loss

        return dequantised_states, loss

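# Hypothetical encode/decode round-trip sketch for the VQ-VAE (added for illustration, not part of
# the original file; the checkpoint name follows the docstring example above and the audio shape is
# an assumption that may differ in practice):
#
#     model = JukeboxVQVAE.from_pretrained("openai/jukebox-1b-lyrics").eval()
#     raw_audio = torch.randn(1, model.sample_length, 1)   # (batch, samples, channels)
#     music_tokens = model.encode(raw_audio)               # one LongTensor of codes per level
#     reconstruction = model.decode(music_tokens)          # back to (batch, samples, channels)
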
class JukeboxMLP(nn.Module):
    def __init__(self, config):
        # a single channel is always used in original code
        super().__init__()
        embed_dim = config.hidden_size
        hidden_dim = int(config.mlp_multiplier * embed_dim)

        self.c_fc = JukeboxConv1D(embed_dim, hidden_dim)
        self.c_proj = JukeboxConv1D(hidden_dim, embed_dim)
        self.act = ACT2FN[config.act_fn]
        self.dropout = nn.Dropout(config.resid_dropout)

    def forward(self, hidden_states):
        hidden_states = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.c_proj(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class JukeboxLayerNorm(FusedLayerNorm):
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
        self.width = np.prod(normalized_shape)
        self.max_numel = 65535 * self.width

    def forward(self, input):
        if input.numel() > self.max_numel:
            return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps).type_as(input)
        else:
            return super().forward(input).type_as(input)

class JukeboxAttention(nn.Module):
    def __init__(self, config, n_ctx, attn_func="dense_attn"):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.n_heads = config.n_heads
        self.dropout = config.attn_dropout
        hidden_dim = int(config.attention_multiplier * self.embed_dim)

        self.head_dim = hidden_dim // config.n_heads
        self.n_ctx = n_ctx
        self.hidden_dim = hidden_dim
        self.scale = self.head_dim**-0.25
        self.mask = config.mask

        if attn_func == "cross_attention":
            self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim)
            self.c_enc_kv = JukeboxConv1D(self.embed_dim, hidden_dim * 2)
        else:
            self.c_attn = JukeboxConv1D(self.embed_dim, hidden_dim * 3)

        self.c_proj = JukeboxConv1D(hidden_dim, self.embed_dim)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.resid_dropout = nn.Dropout(config.resid_dropout)

        # Sequence of length seq_len is factored as [blocks, seq_len // blocks]
        self.attn_func = attn_func
        if attn_func == "cross_attention":
            self.qkv = self.decode_qkv
        elif attn_func == "prime_attn":
            self.qkv = self.prime_qkv
        else:
            self.qkv = self.factored_qkv

        ATTENTION_MAP = {
            "dense_attn": (self.dense_attn, "autoregressive"),
            "block_attn": (self.block_attn, "autoregressive"),
            "transpose_block_attn": (self.transpose_block_attn, "autoregressive"),
            "prev_block_attn": (self.prev_block_attn, None),
            "summary_attn": (self.summary_attn, "summary"),
            "summary_spread_attn": (self.summary_spread_attn, "summary"),
            "cross_attention": (self.dense_attn, None),
            "prime_attn": (self.prime_attn, "prime"),
        }
        self.attn, self.attn_mask = ATTENTION_MAP[attn_func]

        self.blocks = config.blocks
        self.spread = config.spread
        if self.blocks is not None:
            self.block_ctx = self.n_ctx // self.blocks

        self.sample_t = 0
        self.cache = {}
        self.encoder_len = config.nb_relevant_lyric_tokens  # length of the encoder input ids
        self.record_attn = False

    def _attn(self, query_states, key_states, value_states, sample):
        scale = self.scale
        if self.training:
            attention_weight = torch.matmul(query_states * scale, key_states * scale)
        else:
            attention_weight = torch.matmul(query_states, key_states)
            attention_weight.mul_(scale * scale)
        attn_weight_type = attention_weight.dtype
        attention_weight = attention_weight.float()
        if self.mask:
            # Generate appropriate mask to mask out all positions before current
            # Might take up lot of memory for dense, so can cache it
            mask = get_mask(
                self.attn_mask,
                query_states.size(-2),
                key_states.size(-1),
                self.blocks,
                self.spread,
                attention_weight.device,
                sample,
                self.sample_t,
            )
            if mask is not None:
                attention_weight = attention_weight * mask + -1e9 * (1 - mask)
        attention_prob = F.softmax(attention_weight, dim=-1).type(attn_weight_type)
        if self.record_attn:
            self.attention_prob = attention_prob
            if self.attn_func == "prime_attn":
                # only keep music queries and lyrics keys/values
                self.attention_prob = self.attention_prob[:, :, self.encoder_len :, : self.encoder_len]
        attention_prob = self.attn_dropout(attention_prob)
        context_states = torch.matmul(attention_prob, value_states)
        return context_states

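    # Note (added for clarity, not in the original file): `self.scale = head_dim ** -0.25` is applied
    # to both the queries and the keys (or squared and applied once at inference), so the combined
    # factor is the standard 1 / sqrt(head_dim) attention scaling; splitting it across the two
    # operands presumably keeps the intermediate products smaller before the softmax.
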
    def merge_heads(self, hidden_states):
        hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
        new_hidden_states_shape = (*hidden_states.size()[:-2], hidden_states.size(-2) * hidden_states.size(-1))
        return hidden_states.view(*new_hidden_states_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, hidden_states, is_key=False):
        new_hidden_states_shape = (
            *hidden_states.size()[:-1],
            self.n_heads,
            hidden_states.size(-1) // self.n_heads,
        )
        hidden_states = hidden_states.view(*new_hidden_states_shape)  # in Tensorflow implem: fct split_states
        if is_key:
            return hidden_states.permute(0, 2, 3, 1)
        else:
            return hidden_states.permute(0, 2, 1, 3)

    def dense_attn(self, query, key, value, sample):
        query = self.split_heads(query)
        key = self.split_heads(key, is_key=True)
        value = self.split_heads(value)
        context_states = self._attn(query, key, value, sample)
        context_states = self.merge_heads(context_states)
        return context_states

    def block_attn(self, query, key, value, sample):
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            query_length = query.shape[1]
            query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim)
            if query_length < seq_len:
                seq_len = query_length
                key = key[:, -seq_len:].contiguous()
                value = value[:, -seq_len:].contiguous()
            key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
            value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)
            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)

    def transpose_block_attn(self, query, key, value, sample):
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            block_len = (seq_len - 1) % block_ctx
            key = key[:, block_len::block_ctx, :]
            value = value[:, block_len::block_ctx, :]
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            query_length = query.shape[1]
            query = query.view(batch_size, query_length // block_ctx, block_ctx, embed_dim)
            query = query.transpose(1, 2).contiguous()
            query = query.view(batch_size * block_ctx, query_length // block_ctx, embed_dim)

            key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)
            key = key.transpose(1, 2).contiguous()
            key = key.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim)

            value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)
            value = value.transpose(1, 2).contiguous()
            value = value.view(batch_size * block_ctx, seq_len // block_ctx, embed_dim)

            block_attn = self.dense_attn(query, key, value, sample)
            block_attn = block_attn.view(batch_size, block_ctx, query_length // block_ctx, embed_dim)
            block_attn = block_attn.transpose(1, 2).contiguous()
            block_attn = block_attn.view(batch_size, query_length, embed_dim)

            return block_attn

    def prev_block_attn(self, query, key, value, sample):
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            block = (seq_len - 1) // block_ctx
            prev_l = (block - 1) * block_ctx
            if block > 0:
                key = key[:, prev_l : prev_l + block_ctx, :]
                value = value[:, prev_l : prev_l + block_ctx, :]
            else:
                key = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype)
                value = torch.zeros(batch_size, block_ctx, embed_dim, device=query.device, dtype=query.dtype)
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            query_length = query.shape[1]
            query = query.view(batch_size * query_length // block_ctx, block_ctx, embed_dim)

            key = key.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :]
            key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0))
            key = key.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)

            value = value.view(batch_size, seq_len // block_ctx, block_ctx, embed_dim)[:, :-1, :, :]
            value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0))
            value = value.view(batch_size * seq_len // block_ctx, block_ctx, embed_dim)

            if query_length < seq_len:
                nb_query_blocks = query_length // block_ctx
                nb_key_blocks = seq_len // block_ctx
                seq_len = query_length
                key = key.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:]
                key = key.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim)

                value = value.view(batch_size, nb_key_blocks, block_ctx, embed_dim)[:, -nb_query_blocks:]
                value = value.contiguous().view(batch_size * nb_query_blocks, block_ctx, embed_dim)

            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)

    def summary_attn(self, query, key, value, sample):
        blocks = self.blocks
        block_ctx = self.block_ctx
        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            key = key[:, block_ctx - 1 : blocks * block_ctx - 1 : block_ctx, :]
            key = torch.nn.functional.pad(key, (0, 0, 1, 0))

            value = value[:, block_ctx - 1 : blocks * block_ctx - 1 : block_ctx, :]
            value = torch.nn.functional.pad(value, (0, 0, 1, 0))
            return self.dense_attn(query, key, value, sample).view(batch_size, 1, embed_dim)
        else:
            key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :]
            key = torch.nn.functional.pad(key, (0, 0, 1, 0))  # batch_size, blocks, embed_dim

            value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -1, :]
            value = torch.nn.functional.pad(value, (0, 0, 1, 0))  # batch_size, blocks, embed_dim
            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)

    def summary_spread_attn(self, query, key, value, sample):
        blocks = self.blocks
        spread = self.spread

        batch_size, seq_len, embed_dim = value.shape  # For sample, query_len= 1, key_len = value_len = sample_t
        if sample:
            raise NotImplementedError
        else:
            key = key.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :]
            key = torch.nn.functional.pad(key, (0, 0, 0, 0, 1, 0)).contiguous()
            key = key.view(batch_size, blocks * spread, embed_dim)

            value = value.view(batch_size, blocks, seq_len // blocks, embed_dim)[:, :-1, -spread:, :]
            value = torch.nn.functional.pad(value, (0, 0, 0, 0, 1, 0)).contiguous()
            value = value.view(batch_size, blocks * spread, embed_dim)

            return self.dense_attn(query, key, value, sample).view(batch_size, seq_len, embed_dim)

    def prime_attn(self, query, key, value, sample):
        encoder_len = self._encoder_len
        key = key[:, :encoder_len]
        value = value[:, :encoder_len]
        return self.dense_attn(query, key, value, sample)

    def factored_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        curr_ctx = hidden_states.shape[1]
        if last_encoder_hidden_states is not None:
            raise TypeError("last_encoder_hidden_states should be None")

        query, key, value = hidden_states.chunk(3, dim=2)
        if sample:
            self.sample_t += curr_ctx
            key, value = self._append_cache(key, value)
            l_cache = self._suff_cache_len()
            if self._cache_len() > l_cache:
                self._slice_cache(-l_cache)
            if curr_ctx > 1:
                if self.attn_func != "dense_attn":
                    query = self._pad_to_block_ctx(query, query=True)
                    key = self._pad_to_block_ctx(key)
                    value = self._pad_to_block_ctx(value)
                sample = False
            else:
                key = self.cache["key"]
                value = self.cache["value"]
        return query, key, value, sample

    def prime_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        curr_ctx = hidden_states.shape[1]
        if last_encoder_hidden_states is not None:
            raise TypeError("last_encoder_hidden_states should be None")
        query, key, value = hidden_states.chunk(3, dim=2)
        if sample:
            if self._cache_len() < self._encoder_len:
                self._append_cache(key, value)
            if self._cache_len() > self._encoder_len:
                self._slice_cache(0, self._encoder_len)
            key, value = self.cache["key"], self.cache["value"]
            self.sample_t += curr_ctx
        return query, key, value, sample

    def decode_qkv(self, hidden_states, last_encoder_hidden_states=None, sample=False):
        curr_ctx = hidden_states.shape[1]
        query = hidden_states
        if sample:
            if self.sample_t == 0:
                self.cache["key"], self.cache["value"] = self.c_enc_kv(
                    last_encoder_hidden_states.type_as(hidden_states)
                ).chunk(2, dim=2)
            key, value = self.cache["key"], self.cache["value"]
            self.sample_t += curr_ctx
        else:
            key, value = self.c_enc_kv(last_encoder_hidden_states.type_as(hidden_states)).chunk(2, dim=2)
        return query, key, value, sample

  935. def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False):
  936. curr_ctx = hidden_states.shape[1]
  937. hidden_states = self.c_attn(hidden_states)
  938. query, key, value, sample = self.qkv(
  939. hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample
  940. )
  941. attention_scores = self.attn(query, key, value, sample)
  942. if attention_scores.shape[1] != curr_ctx:
  943. offset = self._offset(curr_ctx)
  944. attention_scores = attention_scores[:, offset : offset + curr_ctx, :].contiguous()
  945. attention_scores = self.c_proj(attention_scores)
  946. return self.resid_dropout(attention_scores)
  947. @property
  948. def _encoder_len(self):
  949. encoder_len = self.encoder_len
  950. encoder_blocks = (encoder_len // self.blocks) + 1
  951. return encoder_blocks * self.blocks
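# Worked example (hypothetical values): with self.encoder_len = 512 and self.blocks = 16, the property
# returns (512 // 16 + 1) * 16 = 528, i.e. the lyric length rounded up to the next block boundary
# (one extra block is always added, even when the length is already an exact multiple).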
  952. def _offset(self, curr_ctx):
  953. if self.attn_func == "dense_attn":
  954. return 0
  955. return (self.sample_t - curr_ctx) % self.block_ctx
  956. def _pad_to_block_ctx(self, hidden_states, query=False):
  957. seq_len = hidden_states.shape[1]
  958. offset = self._offset(seq_len) if query else 0
  959. n_blocks = (seq_len + offset + self.block_ctx - 1) // self.block_ctx
  960. pad = n_blocks * self.block_ctx - seq_len - offset
  961. if pad == 0 and offset == 0:
  962. return hidden_states
  963. else:
  964. return F.pad(hidden_states, (0, 0, offset, pad))
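# Worked example (hypothetical shapes): with self.block_ctx = 8, a query of length 13 arriving at
# self.sample_t = 16 gets offset = (16 - 13) % 8 = 3, n_blocks = (13 + 3 + 8 - 1) // 8 = 2 and
# pad = 2 * 8 - 13 - 3 = 0, so F.pad left-pads the sequence by 3 to align it on two blocks of 8.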
  965. def _cache_len(self):
  966. return 0 if "key" not in self.cache else self.cache["key"].shape[1]
  967. def _suff_cache_len(self):
  968. """
  969. Precondition:
  970. key and value are appended with the current context and self.sample_t reflects the 1-indexed sample
  971. location in the context.
  972. """
  973. previous_block_length = (self.sample_t - 1) % self.block_ctx + 1 + self.block_ctx
  974. REQUIRED_CACHE_LEN = {
  975. "dense_attn": self.sample_t,
  976. "block_attn": (self.sample_t - 1) % self.block_ctx + 1,
  977. "transpose_block_attn": self.sample_t,
  978. "prev_block_attn": self.sample_t if self.sample_t <= self.block_ctx else previous_block_length,
  979. "cross_attn": self.encoder_len,
  980. "prime_attn": min(self.sample_t, self._encoder_len),
  981. }
  982. return REQUIRED_CACHE_LEN[self.attn_func]
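# Worked example (hypothetical values): with self.block_ctx = 8 and self.sample_t = 13, "block_attn"
# only needs (13 - 1) % 8 + 1 = 5 cached positions (the partially filled current block), while
# "prev_block_attn" needs 5 + 8 = 13 positions (the current block plus the full previous block).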
  983. def _slice_cache(self, start, end=None):
  984. self.cache["key"] = self.cache["key"][:, start:end]
  985. self.cache["value"] = self.cache["value"][:, start:end]
  986. def _append_cache(self, key, value):
  987. if "key" not in self.cache:
  988. self.cache["key"] = key
  989. self.cache["value"] = value
  990. else:
  991. old_key, old_value = key, value
  992. key = torch.cat([self.cache["key"], old_key], dim=1)
  993. value = torch.cat([self.cache["value"], old_value], dim=1)
  994. del self.cache["key"]
  995. del self.cache["value"]
  996. del old_key
  997. del old_value
  998. self.cache["key"] = key
  999. self.cache["value"] = value
  1000. return self.cache["key"], self.cache["value"]
  1001. def del_cache(self):
  1002. self.sample_t = 0
  1003. if "key" in self.cache:
  1004. del self.cache["key"]
  1005. if "value" in self.cache:
  1006. del self.cache["value"]
  1007. self.cache = {}
  1008. class JukeboxBlock(nn.Module):
  1009. def __init__(self, config, n_ctx, attn_func="dense_attn"):
  1010. super().__init__()
  1011. self.width = config.hidden_size
  1012. self.attn = JukeboxAttention(config, n_ctx, attn_func=attn_func)
  1013. self.layer_norm_0 = JukeboxLayerNorm(config.hidden_size)
  1014. self.mlp = JukeboxMLP(config)
  1015. self.layer_norm_1 = JukeboxLayerNorm(config.hidden_size)
  1016. self.res_scale = 1.0 / config.num_layers if config.attn_res_scale else 1.0
  1017. self.attn_func = attn_func
  1018. def forward(self, hidden_states, last_encoder_hidden_states, sample=False):
  1019. residuals = hidden_states
  1020. hidden_states = self.layer_norm_0(hidden_states)
  1021. hidden_states = self.attn(hidden_states, last_encoder_hidden_states, sample)
  1022. output_states = self.layer_norm_1(residuals + hidden_states)
  1023. output_states = self.mlp(output_states)
  1024. if self.res_scale == 1.0:
  1025. output = residuals + hidden_states + output_states
  1026. else:
  1027. output = residuals + self.res_scale * (hidden_states + output_states)
  1028. return output
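# Illustrative note (hypothetical config): with config.attn_res_scale = True and config.num_layers = 72,
# self.res_scale = 1 / 72, so each block adds a down-weighted residual update,
# output = residuals + (attention + mlp) / 72, instead of the plain sum used when res_scale == 1.0.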
  1029. class JukeboxLayerStack(nn.Module):
  1030. def __init__(self, config, n_ctx):
  1031. super().__init__()
  1032. self.n_ctx = n_ctx
  1033. self.width = config.hidden_size
  1034. self.num_layers = config.num_layers
  1035. self.blocks = config.blocks
  1036. self.attention_pattern = config.attention_pattern
  1037. if self.blocks is not None:
  1038. self.block_ctx = n_ctx // self.blocks
  1039. self.encoder_len = config.nb_relevant_lyric_tokens
  1040. self.n_heads = config.n_heads
  1041. # Orders of attn_func
  1042. attention_pattern = ATTENTION_PATTERNS[self.attention_pattern]
  1043. self._attn_mods = nn.ModuleList()
  1044. for depth in range(self.num_layers):
  1045. self._attn_mods.append(JukeboxBlock(config, n_ctx, attn_func=attention_pattern(depth)))
  1046. self.saved_attn_weights = []
  1047. def set_record_attn(self, record_attn):
  1048. """
  1049. Makes forward prop dump self-attention softmaxes to self.saved_attn_weights.
  1050. Args:
  1051. record_attn (`Union[bool,set]`):
1052. Either a set of layer indices indicating which layers to store, or a boolean value indicating whether
  1053. to dump all.
  1054. """
  1055. def _should_record_attn(layer_idx):
  1056. if isinstance(record_attn, bool):
  1057. return record_attn
  1058. return layer_idx in record_attn
  1059. for i, layer in enumerate(self._attn_mods):
  1060. layer.attn.record_attn = _should_record_attn(i)
  1061. if not record_attn:
  1062. self.saved_attn_weights = []
  1063. def forward(self, hidden_states, last_encoder_hidden_states=None, sample=False):
  1064. # Blocks
  1065. for i, attn_layer in enumerate(self._attn_mods):
  1066. if attn_layer.attn_func == "cross_attention": # attend to the lyrics
  1067. hidden_states = attn_layer(
  1068. hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=sample
  1069. )
  1070. else:
  1071. hidden_states = attn_layer(hidden_states, last_encoder_hidden_states=None, sample=sample)
  1072. if attn_layer.attn.record_attn:
  1073. self.saved_attn_weights.append(attn_layer.attn.c_attn.weight)
  1074. return hidden_states
  1075. def del_cache(self):
  1076. for attn_layer in self._attn_mods:
  1077. attn_layer.attn.del_cache()
  1078. class JukeboxPositionalEmbedding(nn.Module):
  1079. def __init__(self, embed_dim, width):
  1080. super().__init__()
  1081. self.pos_emb = nn.Parameter(torch.empty((embed_dim, width)))
  1082. def forward(self):
  1083. pos_emb = self.pos_emb
  1084. return pos_emb
  1085. class JukeboxConditionalAutoregressive(nn.Module):
  1086. def __init__(
  1087. self,
  1088. config,
  1089. n_ctx=None,
  1090. embed_dim=None,
  1091. audio_conditioning=False,
  1092. metadata_conditioning=False,
  1093. is_encoder=False,
  1094. ):
  1095. """
  1096. Autoregressive model on either lyric tokens or music tokens, or both. The attention pattern should be properly
1097. set for each configuration.
  1098. Args:
  1099. config (`JukeboxPriorConfig`):
  1100. Model configuration class with all the parameters of the model. Initializing with a config file does
  1101. not load the weights associated with the model, only the configuration. Check out the
  1102. [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  1103. n_ctx (`int`, *optional*):
  1104. Number of tokens or lyrics tokens provided in a single pass.
  1105. embed_dim (`int`, *optional*):
1106. Either the dimension of the codebook, or the sum of n_vocab (lyrics) and the codebook dimension
1107. if the model combines lyrics and music tokens, or simply n_vocab if the model is a separate encoder.
  1108. audio_conditioning (`bool`, *optional*, defaults to `False`):
1109. Whether or not the prior supports conditioning on audio.
  1110. metadata_conditioning (`bool`, *optional*, defaults to `False`):
1111. Whether or not the prior supports conditioning on artist, genres, lyrics and timing.
  1112. is_encoder (`bool`, *optional*, defaults to `False`):
  1113. Whether the model is an encoder only model.
  1114. """
  1115. super().__init__()
  1116. self.width = config.hidden_size
  1117. self.num_layers = config.num_layers
  1118. self.n_ctx = n_ctx if n_ctx is not None else config.n_ctx
  1119. self.embed_dim = embed_dim if embed_dim is not None else config.music_vocab_size
  1120. self.embed_tokens = nn.Embedding(self.embed_dim, config.hidden_size)
  1121. self.embed_tokens_dropout = nn.Dropout(config.emb_dropout)
  1122. self.metadata_conditioning = metadata_conditioning
  1123. self.audio_conditioning = audio_conditioning
  1124. if not metadata_conditioning:
  1125. self.start_token = nn.Parameter(torch.empty((1, config.hidden_size)))
  1126. self.pos_emb = JukeboxPositionalEmbedding(self.n_ctx, config.hidden_size)
  1127. self.pos_emb_dropout = nn.Dropout(config.emb_dropout)
  1128. self.transformer = JukeboxLayerStack(config, n_ctx=self.n_ctx)
  1129. self.is_encoder = is_encoder
  1130. self.encoder_len = config.nb_relevant_lyric_tokens
  1131. if config.merged_decoder:
  1132. # Merged piped model uses this setup
  1133. self.add_cond_after_transformer = False
  1134. self.share_embed_tokens_fc_proj_out = False
  1135. else:
  1136. self.add_cond_after_transformer = True
  1137. self.share_embed_tokens_fc_proj_out = True
  1138. if not is_encoder:
  1139. self.fc_proj_out = nn.Linear(config.hidden_size, self.embed_dim, bias=False)
  1140. if self.share_embed_tokens_fc_proj_out:
  1141. self.fc_proj_out.weight = self.embed_tokens.weight
  1142. self.loss = torch.nn.CrossEntropyLoss()
  1143. def forward(
  1144. self,
  1145. tokens,
  1146. audio_conditioning=None,
  1147. metadata_conditioning=None,
  1148. last_encoder_hidden_states=None,
  1149. get_preds=False,
  1150. get_acts=False,
  1151. get_sep_loss=False,
  1152. ):
  1153. """
  1154. Args:
  1155. tokens (`torch.tensor`):
  1156. Can represent music tokens, lyrics tokens or both, depending on the configuration.
  1157. """
  1158. # Preprocess.
  1159. batch_size = tokens.shape[0]
  1160. with torch.no_grad():
  1161. tokens = tokens.view(batch_size, -1).long()
  1162. if not self.audio_conditioning:
  1163. audio_conditioning = torch.zeros(
  1164. (batch_size, 1, self.width),
  1165. device=tokens.device,
  1166. dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype,
  1167. )
  1168. target = tokens # Target
  1169. hidden_states = self.embed_tokens(tokens)
  1170. # Shift by 1, and fill in start token
  1171. hidden_states = torch.cat((hidden_states[:, -1:], hidden_states[:, :-1]), dim=1)
  1172. if self.metadata_conditioning:
  1173. hidden_states[:, 0] = metadata_conditioning.view(batch_size, self.width)
  1174. else:
  1175. hidden_states[:, 0] = self.start_token
  1176. hidden_states = (
  1177. self.embed_tokens_dropout(hidden_states) + self.pos_emb_dropout(self.pos_emb()) + audio_conditioning
  1178. ) # Pos emb and dropout
  1179. hidden_states = self.transformer(
  1180. hidden_states, last_encoder_hidden_states=last_encoder_hidden_states
  1181. ) # Transformer
1182. if self.add_cond_after_transformer: # Piped doesn't add x_cond
  1183. hidden_states = hidden_states + audio_conditioning
  1184. activations = hidden_states
  1185. if self.is_encoder:
  1186. return hidden_states
  1187. hidden_states = self.fc_proj_out(hidden_states) # Predictions
  1188. loss_fn = nn.CrossEntropyLoss()
  1189. if get_sep_loss:
  1190. lyric_hidden_states = hidden_states[:, : self.encoder_len].reshape(-1, self.embed_dim)
  1191. token_hidden_states = hidden_states[:, self.encoder_len :].reshape(-1, self.embed_dim)
  1192. lyric_loss = loss_fn(lyric_hidden_states, target[:, : self.encoder_len].reshape(-1)) / np.log(2.0)
  1193. music_token_loss = loss_fn(token_hidden_states, target[:, self.encoder_len :].reshape(-1)) / np.log(2.0)
  1194. loss = (lyric_loss, music_token_loss) # Note order! Lyric is first
  1195. else:
  1196. loss = loss_fn(hidden_states.view(-1, self.embed_dim), target.view(-1)) / np.log(2.0) # Loss
  1197. if get_preds:
  1198. return loss, hidden_states
  1199. elif get_acts:
  1200. return loss, activations
  1201. else:
  1202. return loss, None
  1203. def get_emb(self, sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning):
  1204. if sample_t == 0:
  1205. hidden_states = torch.empty(n_samples, 1, self.width, dtype=self.embed_tokens.weight.dtype).to(
  1206. self.embed_tokens.weight.device
  1207. )
  1208. if self.metadata_conditioning:
  1209. hidden_states[:, 0] = metadata_conditioning.view(n_samples, self.width)
  1210. else:
  1211. hidden_states[:, 0] = self.start_token
  1212. else:
  1213. hidden_states = self.embed_tokens(tokens)
  1214. if audio_conditioning.shape == (n_samples, self.n_ctx, self.width):
  1215. cond = audio_conditioning[:, sample_t : sample_t + 1, :]
  1216. else:
  1217. cond = audio_conditioning
  1218. # Pos emb, dropout is identity at eval time
  1219. hidden_states = hidden_states + self.pos_emb()[sample_t : sample_t + 1] + cond
  1220. return hidden_states, cond
  1221. def sample(
  1222. self,
  1223. n_samples,
  1224. audio_conditioning=None,
  1225. metadata_conditioning=None,
  1226. last_encoder_hidden_states=None,
  1227. temp=1.0,
  1228. top_k=0,
  1229. top_p=0.0,
  1230. get_preds=False,
  1231. sample_tokens=None,
  1232. ):
  1233. if sample_tokens is None:
  1234. sample_tokens = self.n_ctx
  1235. if not self.audio_conditioning:
  1236. audio_conditioning = torch.zeros(
  1237. (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype
  1238. ).to(self.fc_proj_out.device)
  1239. with torch.no_grad():
  1240. sampled_tokens = []
  1241. tokens = None
  1242. if get_preds:
  1243. preds = []
  1244. iter = tqdm(range(0, sample_tokens), leave=False)
  1245. for sample_t in iter:
  1246. iter.set_description(f"Ancestral sampling {sample_tokens} music tokens", refresh=True)
  1247. hidden_states, cond = self.get_emb(
  1248. sample_t, n_samples, tokens, audio_conditioning, metadata_conditioning
  1249. )
  1250. hidden_states = self.transformer(
  1251. hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True
  1252. )
  1253. if self.add_cond_after_transformer:
  1254. hidden_states = hidden_states + cond
  1255. hidden_states = self.fc_proj_out(hidden_states) # Predictions
  1256. if get_preds:
  1257. preds.append(hidden_states.clone())
  1258. # Adjust logits
  1259. hidden_states = hidden_states / temp
  1260. hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p)
  1261. # Sample and replace hidden_states
  1262. tokens = torch.distributions.Categorical(logits=hidden_states).sample()
  1263. sampled_tokens.append(tokens.clone())
  1264. del tokens
  1265. self.transformer.del_cache()
  1266. tokens = torch.cat(sampled_tokens, dim=1)
  1267. if get_preds:
  1268. preds = torch.cat(preds, dim=1)
  1269. if get_preds:
  1270. return tokens, preds
  1271. else:
  1272. return tokens
  1273. def split_chunks(self, length, chunk_size):
  1274. n_passes = (length + chunk_size - 1) // chunk_size
  1275. chunk_sizes = [*[chunk_size] * (n_passes - 1), (length - 1) % chunk_size + 1]
  1276. return chunk_sizes
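# Example (illustrative): splitting a primed context of 10 tokens into chunks of at most 4 gives three
# forward passes over the past context:
#   >>> self.split_chunks(10, 4)
#   [4, 4, 2]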
  1277. def primed_sample(
  1278. self,
  1279. n_samples,
  1280. lyric_and_music_tokens,
  1281. audio_conditioning=None,
  1282. metadata_conditioning=None,
  1283. last_encoder_hidden_states=None,
  1284. temp=1.0,
  1285. top_k=0,
  1286. top_p=0.0,
  1287. get_preds=False,
  1288. chunk_size=None,
  1289. sample_tokens=None,
  1290. ):
  1291. if sample_tokens is None:
  1292. sample_tokens = self.n_ctx
  1293. # Preprocess.
  1294. batch_size = lyric_and_music_tokens.shape[0]
  1295. with torch.no_grad():
  1296. lyric_and_music_tokens = lyric_and_music_tokens.view(batch_size, -1).long()
  1297. sampled_audio = torch.split(lyric_and_music_tokens, 1, dim=1)
  1298. sampled_audio = list(sampled_audio)
  1299. if not self.audio_conditioning:
  1300. audio_conditioning = torch.zeros(
  1301. (n_samples, 1, self.width), dtype=self.transformer._attn_mods[0].mlp.c_fc.weight.dtype
  1302. ).to(lyric_and_music_tokens.device)
  1303. with torch.no_grad():
  1304. if get_preds:
  1305. preds = []
1306. # Fill up key/value cache for past context by running a forward pass.
  1307. # We do so in chunks instead of doing the whole past in one forward pass to reduce max memory usage.
  1308. if chunk_size is None:
  1309. chunk_size = len(sampled_audio)
  1310. chunk_sizes = self.split_chunks(len(sampled_audio), chunk_size)
  1311. x_primes = []
  1312. start = 0
  1313. token = None
  1314. for current_chunk_size in tqdm(chunk_sizes, desc="Preparing past key value", leave=False):
  1315. sampled_audio_prime, conds_prime = [], []
  1316. for sample_t in range(start, start + current_chunk_size):
  1317. x_prime, cond_prime = self.get_emb(
  1318. sample_t, n_samples, token, audio_conditioning, metadata_conditioning
  1319. )
  1320. token = sampled_audio[sample_t]
  1321. sampled_audio_prime.append(x_prime)
  1322. conds_prime.append(cond_prime)
  1323. start = start + current_chunk_size
  1324. x_prime, cond_prime = torch.cat(sampled_audio_prime, dim=1), torch.cat(conds_prime, dim=1)
  1325. del sampled_audio_prime
  1326. del conds_prime
  1327. if not get_preds:
  1328. del cond_prime
  1329. x_prime = self.transformer(x_prime, last_encoder_hidden_states=last_encoder_hidden_states, sample=True)
  1330. if get_preds:
  1331. if self.add_cond_after_transformer:
  1332. x_prime = x_prime + cond_prime
  1333. del cond_prime
  1334. x_primes.append(x_prime)
  1335. else:
  1336. del x_prime
  1337. if get_preds:
  1338. x_prime = torch.cat(x_primes, dim=1)
  1339. x_prime = self.fc_proj_out(x_prime) # Predictions
  1340. preds.append(x_prime)
  1341. # the input of the encoder and decoder can be merged into (lyrics, music tokens)
  1342. input_tokens = sampled_audio[-1]
1343. iterator = tqdm(
  1344. range(len(sampled_audio), sample_tokens),
  1345. desc=f"Sampling {len(range(len(sampled_audio), sample_tokens))} music tokens",
  1346. leave=False,
  1347. )
1348. for sample_t in iterator:
  1349. hidden_states, cond = self.get_emb(
  1350. sample_t, n_samples, input_tokens, audio_conditioning, metadata_conditioning
  1351. )
  1352. hidden_states = self.transformer(
  1353. hidden_states, last_encoder_hidden_states=last_encoder_hidden_states, sample=True
  1354. )
  1355. if self.add_cond_after_transformer:
  1356. hidden_states = hidden_states + cond
  1357. hidden_states = self.fc_proj_out(hidden_states) # Predictions
  1358. if get_preds:
  1359. preds.append(hidden_states)
  1360. # Adjust logits
  1361. hidden_states = hidden_states / temp
  1362. hidden_states = filter_logits(hidden_states, top_k=top_k, top_p=top_p)
  1363. # only music tokens are sampled
  1364. music_tokens = torch.distributions.Categorical(logits=hidden_states).sample()
  1365. sampled_audio.append(music_tokens.clone())
  1366. input_tokens = music_tokens
  1367. del input_tokens, music_tokens
  1368. self.transformer.del_cache()
  1369. music_tokens = torch.cat(sampled_audio, dim=1)
  1370. if get_preds:
  1371. preds = torch.cat(preds, dim=1)
  1372. if get_preds:
  1373. return music_tokens, preds
  1374. else:
  1375. return music_tokens
  1376. class JukeboxMusicTokenConditioner(nn.Module):
  1377. """
1378. The `JukeboxMusicTokenConditioner` takes music tokens as an input (corresponding to the codes of the VQVAE's
1379. codebook) and upsamples them using a single layer of decoder convolution block (the same as is used in the VQVAE).
  1380. """
  1381. def __init__(self, config, level):
  1382. super().__init__()
  1383. self.embed_tokens = nn.Embedding(config.music_vocab_size, config.hidden_size)
  1384. config.embed_dim = config.music_vocab_size # setting correct argument for the `JukeboxDecoder`
  1385. self.upsampler = JukeboxDecoderConvBock(
  1386. config,
  1387. config.hidden_size,
  1388. config.res_conv_width,
  1389. config.res_conv_depth,
  1390. config.res_downs_t[level],
  1391. config.res_strides_t[level],
  1392. reverse_dilation=False,
  1393. )
  1394. self.layer_norm = JukeboxLayerNorm(config.hidden_size)
  1395. def forward(self, music_tokens, raw_audio_conditionning=None):
  1396. """
  1397. Args:
  1398. music_tokens (`torch.LongTensor`):
1399. Music tokens from the upper level, in range(nb_discrete_codes)
  1400. raw_audio_conditionning (`torch.LongTensor`, *optional*):
1401. Audio used for primed sampling; raw audio information that conditions the generation
  1402. """
  1403. if raw_audio_conditionning is None:
  1404. raw_audio_conditionning = 0.0
  1405. # Embed music_tokens
  1406. music_tokens = music_tokens.long()
  1407. hidden_states = self.embed_tokens(music_tokens)
  1408. hidden_states = hidden_states + raw_audio_conditionning
  1409. # Run conditioner
  1410. hidden_states = hidden_states.permute(0, 2, 1)
  1411. hidden_states = self.upsampler(hidden_states)
  1412. hidden_states = hidden_states.permute(0, 2, 1)
  1413. hidden_states = self.layer_norm(hidden_states)
  1414. return hidden_states
  1415. class JukeboxRangeEmbedding(nn.Module):
  1416. """
1417. The `JukeboxRangeEmbedding` interpolates the given [pos_start, pos_end] to obtain an equivalent of time positional
  1418. embedding of length `n_ctx`.
1419. Binning process: for each pos in the position tensor, find its bin. [start, end) is mapped to [0, 1, ..., bins-1]
1420. via [start, end) -> [0, 1) -> [0, bins) -> floor -> [0, ..., bins-1]. NOTE: open-ended interval on the right, so
1421. start <= pos < end, not pos <= end.
  1422. """
  1423. def __init__(self, n_time, embed_dim, range, out_width, clamp=False):
  1424. super().__init__()
  1425. self.n_time = n_time
  1426. self.embed_dim = embed_dim
  1427. self.emb = nn.Embedding(embed_dim, out_width)
  1428. self.pos_min, self.pos_max = range
  1429. self.clamp = clamp
  1430. def forward(self, pos_start, pos_end=None):
  1431. # Check if [pos_start,pos_end] in [pos_min, pos_max)
  1432. if not len(pos_start.shape) == 2:
  1433. raise TypeError(f"Expected shape with 2 dims, got {pos_start.shape}")
1434. if not ((self.pos_min <= pos_start).all() and (pos_start < self.pos_max).all()):
  1435. raise TypeError(f"Range is [{self.pos_min},{self.pos_max}), got {pos_start}")
  1436. pos_start = pos_start.float()
  1437. if pos_end is not None:
  1438. if self.clamp:
  1439. pos_end = pos_end.clamp(self.pos_min, self.pos_max)
  1440. pos_end = pos_end.float()
  1441. # Interpolate so that [pos_start, ..., pos_end] <-> position tensor of length n_ctx
  1442. n_time = self.n_time
  1443. if n_time != 1:
  1444. interpolation = (
  1445. torch.arange(0, n_time, dtype=torch.float, device=pos_start.device).view(1, n_time) / n_time
  1446. )
  1447. position = pos_start + (pos_end - pos_start) * interpolation
  1448. else:
  1449. position = pos_start
  1450. # Bin each value to bins_
1451. # [0,1) -> [0,1,...,embed_dim) -> floor -> [0,1,...,embed_dim-1]
  1452. normalised_position = (position - self.pos_min) / (self.pos_max - self.pos_min)
  1453. bins_ = (self.embed_dim * normalised_position).floor().long().detach()
  1454. return self.emb(bins_)
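# Worked example (hypothetical values, not a released configuration): with range = (0, 100),
# embed_dim = 10 bins and n_time = 4, pos_start = 20 and pos_end = 40 are interpolated to the
# positions [20, 25, 30, 35]; each is normalised ((25 - 0) / 100 = 0.25) and floored into a bin
# (int(10 * 0.25) = 2), and the corresponding rows of self.emb are returned.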
  1455. class JukeboxLabelConditioner(nn.Module):
  1456. def __init__(self, config, include_time_signal):
  1457. super().__init__()
  1458. embed_dim = config.hidden_size
  1459. timing_dims = config.timing_dims
  1460. sampling_rate = config.sampling_rate
  1461. nb_genres, nb_artists = config.metadata_dims
  1462. music_tokens_shape = config.n_ctx
  1463. self.max_nb_genres = config.max_nb_genres
  1464. self.bow_genre_emb = nn.Embedding(nb_genres, embed_dim)
  1465. self.artist_emb = nn.Embedding(nb_artists, embed_dim)
  1466. self.include_time_signal = include_time_signal
  1467. if self.include_time_signal:
  1468. total_length_range = (config.min_duration * sampling_rate, config.max_duration * sampling_rate)
  1469. absolute_pos_range = (0.0, config.max_duration * sampling_rate)
  1470. relative_pos_range = (0.0, 1.0)
  1471. self.total_length_emb = JukeboxRangeEmbedding(1, timing_dims, total_length_range, embed_dim)
  1472. self.absolute_pos_emb = JukeboxRangeEmbedding(
  1473. music_tokens_shape, timing_dims, absolute_pos_range, embed_dim
  1474. )
  1475. self.relative_pos_emb = JukeboxRangeEmbedding(
  1476. music_tokens_shape, timing_dims, relative_pos_range, embed_dim, clamp=True
  1477. )
  1478. def forward(self, metadata):
  1479. total_length = metadata[:, 0:1]
  1480. offset = metadata[:, 1:2]
  1481. length = metadata[:, 2:3]
  1482. artist = metadata[:, 3:4]
  1483. genre = metadata[:, 4:]
  1484. # Start embedding of length 1
  1485. artist_emb = self.artist_emb(artist)
  1486. # Empty genre slots are denoted by -1. We mask these out.
  1487. mask = (genre >= 0).float().unsqueeze(2)
  1488. genre_emb = (self.bow_genre_emb(genre.clamp(0)) * mask).sum(dim=1, keepdim=True)
  1489. start_emb = genre_emb + artist_emb
  1490. # Pos embedding of length n_ctx
  1491. if self.include_time_signal:
  1492. start, end = offset, offset + length
  1493. total_length = total_length.float()
  1494. start = start.float()
  1495. end = end.float()
  1496. pos_emb = (
  1497. self.total_length_emb(total_length)
  1498. + self.absolute_pos_emb(start, end)
  1499. + self.relative_pos_emb(start / total_length, end / total_length)
  1500. )
  1501. else:
  1502. pos_emb = None
  1503. return start_emb, pos_emb
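# Illustrative metadata row (hypothetical ids): [total_length, offset, length, artist_id, genre ids...],
# e.g. [8192000, 0, 786432, 10, 32, -1, -1], where the trailing -1 entries are empty genre slots that
# are masked out of the bag-of-words genre embedding above.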
  1504. class JukeboxPrior(PreTrainedModel):
  1505. """
1506. The JukeboxPrior class, which is a wrapper around the various conditioning modules and the transformer. JukeboxPrior
1507. can be seen as a language model trained on music. It models the next `music token` prediction task. If a (lyric) `encoder`
1508. is defined, it also models the `next character` prediction on the lyrics. Can be conditioned on timing, artist,
1509. genre, lyrics and codes from the lower-level priors.
  1510. Args:
  1511. config (`JukeboxPriorConfig`):
  1512. Model configuration class with all the parameters of the model. Initializing with a config file does not
  1513. load the weights associated with the model, only the configuration. Check out the
  1514. [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  1515. level (`int`, *optional*):
  1516. Current level of the Prior. Should be in range `[0,nb_priors]`.
  1517. nb_priors (`int`, *optional*, defaults to 3):
  1518. Total number of priors.
  1519. vqvae_encoder (`Callable`, *optional*):
  1520. Encoding method of the VQVAE encoder used in the forward pass of the model. Passing functions instead of
  1521. the vqvae module to avoid getting the parameters.
  1522. vqvae_decoder (`Callable`, *optional*):
  1523. Decoding method of the VQVAE decoder used in the forward pass of the model. Passing functions instead of
  1524. the vqvae module to avoid getting the parameters.
  1525. """
  1526. config_class = JukeboxPriorConfig
  1527. def _init_weights(self, module):
  1528. init_scale = self.config.init_scale
  1529. if isinstance(module, nn.Embedding):
  1530. module.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
  1531. elif isinstance(module, JukeboxConv1D):
  1532. if self.config.zero_out:
  1533. module.weight.data.zero_()
  1534. else:
  1535. module.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
  1536. elif isinstance(module, JukeboxPositionalEmbedding):
  1537. module.pos_emb.data.normal_(mean=0.0, std=0.01 * init_scale)
  1538. elif isinstance(module, JukeboxRangeEmbedding):
  1539. module.emb.weight.data.normal_(mean=0.0, std=0.01 * init_scale)
  1540. elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "lm_head"):
  1541. module.lm_head.weight.data.normal_(mean=0.0, std=0.02 * init_scale)
  1542. elif isinstance(module, JukeboxConditionalAutoregressive) and hasattr(module, "start_token"):
  1543. module.start_token.data.normal_(mean=0.0, std=0.01 * init_scale)
  1544. elif isinstance(module, JukeboxResConv1DBlock) and self.config.zero_out:
1545. module.conv1d_2.weight.data.zero_()
  1546. module.conv1d_2.bias.data.zero_()
  1547. if isinstance(module, nn.LayerNorm):
  1548. module.bias.data.zero_()
  1549. module.weight.data.fill_(1.0)
  1550. if isinstance(module, nn.Linear) and module.bias is not None:
  1551. module.bias.data.zero_()
  1552. def __init__(self, config: JukeboxPriorConfig, level=None, nb_priors=3, vqvae_encoder=None, vqvae_decoder=None):
  1553. super().__init__(config)
  1554. # Passing functions instead of the vqvae module to avoid getting params, only used in the
  1555. # forward loop
  1556. self.vqvae_encoder = vqvae_encoder
  1557. self.vqvae_decoder = vqvae_decoder
  1558. self.levels = nb_priors
  1559. self.level = level if level is not None else config.level
  1560. self.base_model_prefix = f"priors.{self.level}"
  1561. self.n_ctx = config.n_ctx
  1562. self.lyric_conditioning = config.nb_relevant_lyric_tokens > 0
  1563. self.nb_relevant_lyric_tokens = config.nb_relevant_lyric_tokens
  1564. self.encoder_loss_fraction = config.encoder_loss_fraction
  1565. # Audio conditioning : conditioning on music tokens (either from audio or from previous levels or both)
  1566. self.audio_conditioning = self.level != 0
  1567. self.cond_level = self.level - 1
  1568. if self.audio_conditioning:
  1569. self.conditioner_blocks = JukeboxMusicTokenConditioner(config, self.level)
1570. # metadata conditioning : conditioning on timing, genres, and artist
  1571. self.metadata_conditioning = config.metadata_conditioning
  1572. if self.metadata_conditioning:
  1573. self.metadata_embedding = JukeboxLabelConditioner(config, include_time_signal=not self.audio_conditioning)
  1574. # define encoder-decoder or encoder and decoder
  1575. self.is_encoder_decoder = config.is_encoder_decoder
  1576. if config.is_encoder_decoder:
  1577. # encoder-decoder transformer
  1578. self.input_shapes = [config.nb_relevant_lyric_tokens, config.n_ctx]
  1579. self.embed_dim_shift = [0, config.lyric_vocab_size]
  1580. self.width = config.hidden_size
  1581. self.nb_relevant_lyric_tokens = config.nb_relevant_lyric_tokens
  1582. self.prior = JukeboxConditionalAutoregressive(
  1583. config,
  1584. n_ctx=config.nb_relevant_lyric_tokens + config.n_ctx,
  1585. embed_dim=config.lyric_vocab_size + config.music_vocab_size,
  1586. audio_conditioning=(self.audio_conditioning or self.metadata_conditioning),
  1587. metadata_conditioning=True,
  1588. )
  1589. else:
1590. # Separate encoder and decoder transformers
  1591. encoder_config = config.encoder_config
  1592. if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning:
  1593. self.lyric_acts_width = encoder_config.hidden_size
  1594. self.encoder_width = config.hidden_size
  1595. self.encoder_dim = config.lyric_vocab_size
  1596. self.encoder = JukeboxConditionalAutoregressive(
  1597. encoder_config,
  1598. n_ctx=self.nb_relevant_lyric_tokens,
  1599. embed_dim=self.encoder_dim,
  1600. audio_conditioning=False,
  1601. metadata_conditioning=False,
  1602. is_encoder=True,
  1603. )
  1604. self.encoder.proj_in = JukeboxConv1D(encoder_config.hidden_size, config.hidden_size)
  1605. self.encoder.final_layer_norm = JukeboxLayerNorm(config.hidden_size)
  1606. self.encoder.lm_head = nn.Linear(config.hidden_size, config.lyric_vocab_size, bias=False)
  1607. else:
  1608. self.nb_relevant_lyric_tokens = 0
  1609. # decoder model on the tokens
  1610. self.prior = JukeboxConditionalAutoregressive(
  1611. config,
  1612. audio_conditioning=(self.audio_conditioning or self.metadata_conditioning),
  1613. metadata_conditioning=self.metadata_conditioning,
  1614. )
  1615. self.next_token_prediction_loss_dims = config.n_ctx
  1616. self.total_loss_dims = self.nb_relevant_lyric_tokens + self.next_token_prediction_loss_dims
  1617. self.downsamples = [stride**down for stride, down in zip(config.res_strides_t, config.res_downs_t)]
  1618. self.cond_downsample = self.downsamples[self.level] if self.level != 0 else None
  1619. self.raw_to_tokens = np.prod(self.downsamples[: nb_priors - self.level])
  1620. self.sample_length = self.n_ctx * self.raw_to_tokens
  1621. logger.info(
  1622. f"Level:{self.level}, Cond downsample:{self.cond_downsample}, Raw to tokens:{self.raw_to_tokens}, Sample"
  1623. f" length:{self.sample_length}"
  1624. )
  1625. def get_metadata(self, labels, start, total_length, offset, get_indices=False):
  1626. metadata = labels.clone()
  1627. metadata[:, 0] = total_length
  1628. # Set sample_length to match this level
  1629. metadata[:, 2] = int(self.sample_length)
  1630. # Set offset
  1631. metadata[:, 1:2] = int(offset * self.raw_to_tokens) + int(start * self.raw_to_tokens)
1632. # here, since metadata has the full token_list, we just need to select the ones that are relevant
  1633. # Set lyric tokens
  1634. metadata, indices = self.set_metadata_lyric_tokens(metadata)
  1635. if get_indices:
  1636. return metadata, indices
  1637. else:
  1638. return metadata
  1639. def set_metadata_lyric_tokens(self, labels):
  1640. """
1641. Processes the full labels to only retrieve the relevant lyric tokens and keep the metadata conditioning tokens.
  1642. """
  1643. if self.nb_relevant_lyric_tokens > 0:
  1644. tokens_list = torch.zeros(
  1645. (labels.shape[0], self.nb_relevant_lyric_tokens), dtype=torch.long, device=labels.device
  1646. )
1647. indices_list = [] # what's the index of each current character in the original array
  1648. for idx in range(labels.shape[0]):
  1649. full_tokens = labels.clone()[:, 4 + self.metadata_embedding.max_nb_genres :]
  1650. total_length, offset, duration = labels[idx, 0], labels[idx, 1], labels[idx, 2]
  1651. tokens, indices = get_relevant_lyric_tokens(
  1652. full_tokens, self.nb_relevant_lyric_tokens, total_length, offset, duration
  1653. )
  1654. tokens_list[idx, :] = tokens
  1655. indices_list.append(indices)
  1656. return (
  1657. torch.cat((labels[:, : 4 + self.metadata_embedding.max_nb_genres], tokens_list), dim=-1),
  1658. indices_list,
  1659. )
  1660. else:
  1661. return labels, None
  1662. def get_music_tokens_conds(self, music_tokens, start, end):
  1663. """
  1664. Extracts current level's conditioning music tokens.
  1665. """
  1666. if self.level != 0:
  1667. music_tokens_cond = music_tokens[self.level - 1]
  1668. music_tokens = music_tokens_cond[:, start // self.cond_downsample : end // self.cond_downsample]
  1669. missing_cond_len = self.n_ctx // self.cond_downsample - music_tokens_cond[-1].shape[-1]
  1670. if missing_cond_len > 0:
  1671. init_cond = torch.zeros(1, missing_cond_len).to(music_tokens_cond.device)
  1672. music_tokens_cond = torch.cat((music_tokens_cond, init_cond), dim=-1).long()
  1673. music_tokens_conds = [music_tokens_cond]
  1674. else:
  1675. music_tokens_conds = None
  1676. return music_tokens_conds
  1677. def prior_preprocess(self, tokens, conds):
  1678. """
1679. Shifts the input tokens to account for the dictionary merge. The embed_dim_shift gives the amount by which the
1680. music tokens should be shifted. It is equal to `lyric_vocab_size`.
  1681. """
  1682. batch_size = tokens[0].shape[0]
  1683. for i in range(len(tokens)):
  1684. tokens[i] = (tokens[i] + int(self.embed_dim_shift[i])).view(batch_size, -1)
  1685. for i in range(len(conds)):
  1686. if conds[i] is None:
  1687. conds[i] = torch.zeros(
  1688. (batch_size, self.input_shapes[i], self.width), dtype=tokens[0].dtype, device=tokens[0].device
  1689. )
  1690. return torch.cat(tokens, dim=1), torch.cat(conds, dim=1)
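# Illustrative sketch of the shift (hypothetical sizes): with embed_dim_shift = [0, lyric_vocab_size]
# and lyric_vocab_size = 80, lyric token ids are left unchanged while music token ids [3, 5] become
# [83, 85], so both vocabularies can share the merged embedding table; `prior_postprocess` reverses
# this by subtracting the same shift before returning the music tokens.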
  1691. def prior_postprocess(self, tokens):
  1692. """
  1693. Shifts back the input tokens if the model uses an encoder decoder architecture. As the embedding layer is
  1694. shared, `prior_embed_dim_shift` shifts the music token ids by `lyric_vocab_size`. Only returns the music
  1695. tokens.
  1696. """
  1697. batch_size = tokens.shape[0]
  1698. dims = (self.input_shapes[0], tokens.shape[1] - self.input_shapes[0])
  1699. tokens = list(torch.split(tokens, dims, dim=1))
1700. # Some of the input tokens might be shifted to take into account the vocabulary fusion
  1701. for i in range(len(tokens)):
  1702. bins_shift = int(self.embed_dim_shift[i])
  1703. tokens[i] = (tokens[i] - bins_shift).view(batch_size, -1)
  1704. tokens[i] = torch.clamp(tokens[i], min=0)
  1705. # If not masking loss, model may have generated lyric/midi tokens which are now shifted <0 by bin_shift
  1706. return tokens[-1]
  1707. def embed_tokens(self, music_tokens_conds):
  1708. """
  1709. Embeds the upper level music tokens and upsamples them to provide as audio conditioning.
  1710. """
  1711. music_tokens_conds = music_tokens_conds[: self.cond_level + 1]
  1712. audio_conditioning = None
  1713. for music_tokens_cond, conditioner_block in reversed(list(zip(music_tokens_conds, [self.conditioner_blocks]))):
  1714. audio_conditioning = conditioner_block(music_tokens_cond, audio_conditioning)
  1715. return audio_conditioning
  1716. def encode(self, hidden_states, start_level=None, end_level=None, bs_chunks=1):
  1717. """
  1718. Encodes the hidden states (raw audio) using the VQVAE's encoder. Returns latent_states.
  1719. """
  1720. if start_level is None:
  1721. start_level = self.level
  1722. if end_level is None:
  1723. end_level = self.levels
  1724. # Get latents
  1725. with torch.no_grad():
  1726. latent_states = self.vqvae_encoder(
  1727. hidden_states, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks
  1728. )
  1729. return latent_states
  1730. def decode(self, music_tokens, start_level=None, end_level=None, bs_chunks=1):
  1731. """
1732. Upsamples the sequence of codebook vectors to raw audio.
  1733. """
  1734. if start_level is None:
  1735. start_level = self.level
  1736. if end_level is None:
  1737. end_level = self.levels
  1738. with torch.no_grad():
  1739. output = self.vqvae_decoder(
  1740. music_tokens, start_level=start_level, end_level=end_level, bs_chunks=bs_chunks
  1741. )
  1742. return output
  1743. def get_cond(self, music_tokens_conds, metadata):
  1744. """
1745. Converts the input tokens to input_embeddings. Splits the lyrics from the rest of the metadata. Lyric tokens
  1746. can be None.
  1747. """
  1748. if metadata is not None:
  1749. n_labels = metadata.shape[1] - self.nb_relevant_lyric_tokens
  1750. metadata, lyric_tokens = metadata[:, :n_labels], metadata[:, n_labels:]
  1751. else:
  1752. metadata, lyric_tokens = None, None
  1753. metadata_conditioning, metadata_pos = (
  1754. self.metadata_embedding(metadata) if self.metadata_conditioning else (None, None)
  1755. )
  1756. audio_conditioning = self.embed_tokens(music_tokens_conds) if self.audio_conditioning else metadata_pos
  1757. return audio_conditioning, metadata_conditioning, lyric_tokens
  1758. def sample(
  1759. self,
  1760. n_samples,
  1761. music_tokens=None,
  1762. music_tokens_conds=None,
  1763. metadata=None,
  1764. temp=1.0,
  1765. top_k=0,
  1766. top_p=0.0,
  1767. chunk_size=None,
  1768. sample_tokens=None,
  1769. ):
  1770. """
1771. Ancestral/primed sampling of a window of tokens using the provided conditioning and metadata.
  1772. Args:
  1773. n_samples (`int`):
  1774. Number of samples to generate.
  1775. music_tokens (`List[torch.LongTensor]`, *optional*):
1776. Previously generated tokens at the current level. Used as context for the generation.
  1777. music_tokens_conds (`List[torch.FloatTensor]`, *optional*):
  1778. Upper-level music tokens generated by the previous prior model. Is `None` if the generation is not
1779. conditioned on the upper-level tokens.
  1780. metadata (`List[torch.LongTensor]`, *optional*):
1781. List containing the metadata tensor with the artist, genre and the lyric tokens.
  1782. temp (`float`, *optional*, defaults to 1.0):
  1783. Sampling temperature.
  1784. top_k (`int`, *optional*, defaults to 0):
  1785. Top k probabilities used for filtering.
  1786. top_p (`float`, *optional*, defaults to 0.0):
  1787. Top p probabilities used for filtering.
  1788. chunk_size (`int`, *optional*):
  1789. Size of the chunks used to prepare the cache of the transformer.
  1790. sample_tokens (`int`, *optional*):
  1791. Number of tokens to sample.
  1792. """
  1793. no_past_context = music_tokens is None or music_tokens.shape[1] == 0
  1794. name = {True: "Ancestral", False: "Primed"}[no_past_context]
  1795. logger.info(f"{name} sampling {n_samples} samples with temp={temp}, top_k={top_k}, top_p={top_p}")
  1796. with torch.no_grad():
  1797. # Currently audio_conditioning only uses immediately above layer
  1798. audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata)
  1799. if self.is_encoder_decoder:
  1800. if no_past_context: # the prime_sample function will be used with music_tokens set to None
  1801. lyric_and_music_tokens, audio_conditioning = self.prior_preprocess(
  1802. [lyric_tokens], [None, audio_conditioning]
  1803. )
  1804. else:
  1805. lyric_and_music_tokens, audio_conditioning = self.prior_preprocess(
  1806. [lyric_tokens, music_tokens], [None, audio_conditioning]
  1807. )
  1808. if sample_tokens is not None:
  1809. sample_tokens += self.nb_relevant_lyric_tokens
  1810. music_tokens = self.prior.primed_sample(
  1811. n_samples,
  1812. lyric_and_music_tokens,
  1813. audio_conditioning,
  1814. metadata_conditioning,
  1815. temp=temp,
  1816. top_k=top_k,
  1817. top_p=top_p,
  1818. chunk_size=chunk_size,
  1819. sample_tokens=sample_tokens,
  1820. )
  1821. music_tokens = self.prior_postprocess(music_tokens)
  1822. else:
  1823. last_encoder_hidden_states = self.get_encoder_states(lyric_tokens, sample=True)
  1824. if no_past_context:
  1825. music_tokens = self.prior.sample(
  1826. n_samples,
  1827. audio_conditioning,
  1828. metadata_conditioning,
  1829. last_encoder_hidden_states,
  1830. temp=temp,
  1831. top_k=top_k,
  1832. top_p=top_p,
  1833. sample_tokens=sample_tokens,
  1834. )
  1835. else:
  1836. music_tokens = self.prior.primed_sample(
  1837. n_samples,
  1838. music_tokens,
  1839. audio_conditioning,
  1840. metadata_conditioning,
  1841. last_encoder_hidden_states,
  1842. temp=temp,
  1843. top_k=top_k,
  1844. top_p=top_p,
  1845. chunk_size=chunk_size,
  1846. sample_tokens=sample_tokens,
  1847. )
  1848. return music_tokens
  1849. def get_encoder_states(self, lyric_tokens, sample=False):
  1850. """
1851. Retrieves the last hidden_states of the lyric encoder that will be attended to by the decoder. Forwards through
  1852. the lyric encoder.
  1853. """
  1854. if self.nb_relevant_lyric_tokens != 0 and self.lyric_conditioning:
  1855. if sample:
  1856. self.encoder = self.encoder.to(lyric_tokens.device)
  1857. lyric_acts = self.encoder(lyric_tokens, None, None, None)
  1858. lyric_acts = self.encoder.proj_in(lyric_acts)
  1859. last_encoder_hidden_states = self.encoder.final_layer_norm(lyric_acts)
  1860. else:
  1861. last_encoder_hidden_states = None
  1862. return last_encoder_hidden_states
  1863. def get_encoder_loss(self, last_encoder_hidden_states, target_lyrics):
  1864. """
  1865. Computes the loss for the lyric encoder: next lyric token prediction.
  1866. """
  1867. if self.lyric_conditioning:
  1868. last_encoder_hidden_states = self.encoder.lm_head(last_encoder_hidden_states)
  1869. encoder_loss = nn.functional.cross_entropy(
  1870. last_encoder_hidden_states.view(-1, self.encoder_dim), target_lyrics.view(-1)
  1871. ) / np.log(2.0)
  1872. else:
  1873. encoder_loss = torch.tensor(0.0, device=last_encoder_hidden_states.device)
  1874. return encoder_loss
  1875. def forward_tokens(
  1876. self, music_tokens, music_tokens_conds=[], metadata=None, get_preds=False, get_attn_weights=False
  1877. ):
  1878. """
  1879. Applies a forward pass using the conditioning tokens. Different from the classic forward as it does not use the
  1880. vqvae's encoding layers.
  1881. """
  1882. if get_attn_weights:
  1883. self.prior.transformer.set_record_attn(get_attn_weights)
  1884. audio_conditioning, metadata_conditioning, lyric_tokens = self.get_cond(music_tokens_conds, metadata)
  1885. if self.is_encoder_decoder: # the preprocess returns the full tokens (Lyrics and Music tokens), shifted
  1886. tokens, audio_conditioning = self.prior_preprocess(
  1887. [lyric_tokens, music_tokens], [None, audio_conditioning]
  1888. )
  1889. (encoder_loss, next_token_prediction_loss), preds = self.prior(
  1890. tokens, audio_conditioning, metadata_conditioning, get_sep_loss=True, get_preds=get_preds
  1891. )
  1892. else:
  1893. last_encoder_hidden_states = self.get_encoder_states(lyric_tokens)
  1894. encoder_loss = self.get_encoder_loss(last_encoder_hidden_states, lyric_tokens)
  1895. next_token_prediction_loss, preds = self.prior(
  1896. music_tokens,
  1897. audio_conditioning,
  1898. metadata_conditioning,
  1899. last_encoder_hidden_states,
  1900. get_preds=get_preds,
  1901. )
  1902. loss = self.encoder_loss_fraction * encoder_loss * self.nb_relevant_lyric_tokens / self.total_loss_dims
  1903. loss += next_token_prediction_loss * self.next_token_prediction_loss_dims / self.total_loss_dims
  1904. metrics = {
  1905. "bpd": next_token_prediction_loss.clone().detach(),
  1906. "encoder_loss": encoder_loss.clone().detach(),
  1907. "next_token_prediction_loss": next_token_prediction_loss.clone().detach(),
  1908. }
  1909. if get_preds:
  1910. metrics["preds"] = preds.clone().detach()
  1911. if get_attn_weights:
  1912. saved_attn_weights = self.prior.transformer.saved_attn_weights
  1913. self.prior.transformer.set_record_attn(False)
  1914. return saved_attn_weights
  1915. else:
  1916. return loss, metrics
  1917. def forward(
  1918. self,
  1919. hidden_states: torch.Tensor,
  1920. metadata: Optional[List[torch.LongTensor]],
  1921. decode: Optional[bool] = False,
  1922. get_preds: Optional[bool] = False,
  1923. ) -> List[torch.Tensor]:
  1924. """
1925. Encodes the hidden states using the `vqvae` encoder, and then predicts the next token in the `forward_tokens`
  1926. function. The loss is the sum of the `encoder` loss and the `decoder` loss.
  1927. Args:
  1928. hidden_states (`torch.Tensor`):
  1929. Hidden states which should be raw audio
  1930. metadata (`List[torch.LongTensor]`, *optional*):
1931. List containing the metadata conditioning tensor with the lyric and the metadata tokens.
  1932. decode (`bool`, *optional*, defaults to `False`):
  1933. Whether or not to decode the encoded to tokens.
  1934. get_preds (`bool`, *optional*, defaults to `False`):
1935. Whether or not to return the actual predictions of the model.
  1936. """
  1937. batch_size = hidden_states.shape[0]
  1938. music_tokens, *music_tokens_conds = self.encode(hidden_states, bs_chunks=batch_size)
  1939. loss, metrics = self.forward_tokens(
  1940. music_tokens=music_tokens,
  1941. music_tokens_conds=music_tokens_conds,
  1942. metadata=metadata,
  1943. get_preds=get_preds,
  1944. )
  1945. if decode:
  1946. dequantised_states = self.decode([music_tokens, *music_tokens_conds])
  1947. else:
  1948. dequantised_states = None
  1949. return dequantised_states, loss, metrics
  1950. class JukeboxPreTrainedModel(PreTrainedModel):
  1951. """
  1952. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  1953. models.
  1954. """
  1955. config_class = JukeboxConfig
  1956. base_model_prefix = "jukebox"
  1957. supports_gradient_checkpointing = False
  1958. def _init_weights(self, module):
  1959. if isinstance(module, JukeboxPrior) or isinstance(module, JukeboxVQVAE):
  1960. module.apply(module._init_weights)
  1961. def __init__(self, *inputs, **kwargs):
  1962. super().__init__(*inputs, **kwargs)
  1963. JUKEBOX_SAMPLING_INPUT_DOCSTRING = r"""
  1964. labels (`List[torch.LongTensor]` of length `n_sample`, and shape `(self.levels, self.config.max_nb_genre + lyric_sequence_length)` :
  1965. List of metadata such as `artist_id`, `genre_id` and the full list of lyric tokens which are used to
  1966. condition the generation.
  1967. sampling_kwargs (`Dict[Any]`):
1968. Various additional sampling arguments that are used by the `_sample` function. A detailed list of the
1969. arguments can be seen in the [`_sample`] function documentation.
  1970. """
  1971. @add_start_docstrings(
  1972. """The bare JUKEBOX Model used for music generation. 4 sampling techniques are supported : `primed_sample`, `upsample`,
  1973. `continue_sample` and `ancestral_sample`. It does not have a `forward` method as the training is not end to end. If
  1974. you want to fine-tune the model, it is recommended to use the `JukeboxPrior` class and train each prior
  1975. individually.
  1976. """,
  1977. JUKEBOX_START_DOCSTRING,
  1978. )
  1979. class JukeboxModel(JukeboxPreTrainedModel):
  1980. _no_split_modules = ["JukeboxBlock"]
  1981. def __init__(self, config):
  1982. super().__init__(config)
  1983. vqvae_config = config.vqvae_config
  1984. self.vqvae = JukeboxVQVAE(vqvae_config)
  1985. self.set_shared_params(config)
  1986. self.priors = nn.ModuleList(
  1987. [JukeboxPrior(config.prior_configs[level], level) for level in range(config.nb_priors)]
  1988. )
  1989. def set_shared_params(self, model_config):
  1990. """
  1991. Initialises the parameters that are shared. This has to be done here because the list of `JukeboxPriorConfig`
1992. is nested, and is thus unreachable in the `from_dict` function.
  1993. """
  1994. for config in model_config.prior_configs:
  1995. config.sampling_rate = model_config.sampling_rate
  1996. config.timing_dims = model_config.timing_dims
  1997. config.min_duration = model_config.min_duration
  1998. config.max_duration = model_config.max_duration
  1999. config.max_nb_genres = model_config.max_nb_genres
  2000. config.metadata_conditioning = model_config.metadata_conditioning
  2001. def decode(self, music_tokens, start_level=0, end_level=None, bs_chunks=1):
  2002. return self.vqvae.decode(music_tokens, start_level, end_level, bs_chunks)
  2003. def encode(self, input_audio, start_level=0, end_level=None, bs_chunks=1):
  2004. return self.vqvae.encode(input_audio, start_level, end_level, bs_chunks)
  2005. def split_batch(self, obj, n_samples, split_size):
  2006. n_passes = (n_samples + split_size - 1) // split_size
  2007. if isinstance(obj, torch.Tensor):
  2008. return torch.split(obj, split_size, dim=0)
  2009. elif isinstance(obj, list):
  2010. return list(zip(*[torch.split(item, split_size, dim=0) for item in obj]))
  2011. elif obj is None:
  2012. return [None] * n_passes
  2013. else:
  2014. raise TypeError("Unknown input type")
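# Example (illustrative): with n_samples = 5 and split_size = 2, a tensor of shape (5, n_ctx) is split
# into sub-batches of 2, 2 and 1 samples, and a `None` conditioning input is expanded to
# [None, None, None] so that it can still be zipped with the other per-pass lists.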
  2015. # Sample a partial window of length<n_ctx with tokens_to_sample new tokens on level=level
  2016. def sample_partial_window(
  2017. self, music_tokens, labels, offset, sampling_kwargs, level, tokens_to_sample, max_batch_size
  2018. ):
  2019. prior = self.priors[level]
  2020. sampled_tokens = music_tokens[level]
  2021. n_ctx = prior.n_ctx
  2022. nb_sampled_tokens = sampled_tokens.shape[1]
  2023. if nb_sampled_tokens < n_ctx - tokens_to_sample:
  2024. sampling_kwargs["sample_tokens"] = nb_sampled_tokens + tokens_to_sample
  2025. start = 0
  2026. else:
  2027. sampling_kwargs["sample_tokens"] = n_ctx
  2028. start = nb_sampled_tokens - n_ctx + tokens_to_sample
  2029. return self.sample_single_window(music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size)
  2030. # Sample a single window of length=n_ctx at position=start on level=level
  2031. def sample_single_window(self, music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size):
  2032. prior = self.priors[level]
  2033. n_samples = music_tokens[0].shape[0]
  2034. n_ctx = prior.n_ctx
  2035. end = start + n_ctx
  2036. # get music_tokens already sampled at current level
  2037. previous_sampled_tokens = music_tokens[level][:, start:end]
  2038. sample_tokens = sampling_kwargs.get("sample_tokens", None)
  2039. if "sample_tokens" in sampling_kwargs:
  2040. sample_tokens = end - start
  2041. conditioning_tokens = previous_sampled_tokens.shape[1]
  2042. new_tokens = sample_tokens - previous_sampled_tokens.shape[1]
  2043. logger.info(
  2044. f"Sampling {sample_tokens} tokens for [{start},{start+sample_tokens}]. Conditioning on"
  2045. f" {conditioning_tokens} tokens"
  2046. )
  2047. if new_tokens <= 0:
  2048. # Nothing new to sample
  2049. return music_tokens
  2050. # get music_tokens_conds from level above
  2051. music_tokens_conds = prior.get_music_tokens_conds(music_tokens, start, end)
  2052. # if there are no levels above should return None!
  2053. # set metadata offset, sample_length and lyrics tokens
  2054. metadata = prior.get_metadata(labels, start, self.total_length, offset)
  2055. music_tokens_list = self.split_batch(previous_sampled_tokens, n_samples, max_batch_size)
  2056. music_tokens_conds_list = self.split_batch(music_tokens_conds, n_samples, max_batch_size)
  2057. metadata_list = self.split_batch(metadata, n_samples, max_batch_size)
  2058. tokens = []
  2059. iterator = tqdm(zip(music_tokens_list, music_tokens_conds_list, metadata_list), leave=False)
  2060. for music_tokens_i, music_tokens_conds_i, metadata_i in iterator:
  2061. name = ["Ancestral", "Primed"][music_tokens_i.shape[1] == 0]
  2062. iterator.set_description(
  2063. f"[prior level {level}] {name} Sampling {sample_tokens} tokens out of"
  2064. f" {self.total_length//prior.raw_to_tokens}",
  2065. refresh=True,
  2066. )
  2067. tokens_i = prior.sample(
  2068. n_samples=music_tokens_i.shape[0],
  2069. music_tokens=music_tokens_i,
  2070. music_tokens_conds=music_tokens_conds_i,
  2071. metadata=metadata_i,
  2072. **sampling_kwargs,
  2073. )
  2074. tokens.append(tokens_i)
  2075. sampled_tokens = torch.cat(tokens, dim=0)
  2076. # Update music_tokens with new sample
  2077. music_tokens_new = sampled_tokens[:, -new_tokens:]
  2078. music_tokens[level] = torch.cat([music_tokens[level], music_tokens_new], dim=1)
  2079. return music_tokens
  2080. # Sample total_length tokens at level=level with hop_length=hop_length
  2081. def sample_level(
  2082. self, music_tokens, labels, offset, sampling_kwargs, level, total_length, hop_length, max_batch_size
  2083. ):
  2084. if total_length >= self.priors[level].n_ctx:
  2085. iterator = get_starts(total_length, self.priors[level].n_ctx, hop_length)
  2086. for start in iterator:
  2087. music_tokens = self.sample_single_window(
  2088. music_tokens, labels, offset, sampling_kwargs, level, start, max_batch_size
  2089. )
  2090. else:
  2091. music_tokens = self.sample_partial_window(
  2092. music_tokens, labels, offset, sampling_kwargs, level, total_length, max_batch_size
  2093. )
  2094. return music_tokens
  2095. @torch.no_grad()
  2096. def _sample(
  2097. self,
  2098. music_tokens,
  2099. labels,
  2100. sample_levels,
  2101. metas=None,
  2102. chunk_size=32,
  2103. sampling_temperature=0.98,
  2104. lower_batch_size=16,
  2105. max_batch_size=16,
  2106. sample_length_in_seconds=24,
  2107. compute_alignments=False,
  2108. sample_tokens=None,
  2109. offset=0,
  2110. save_results=True,
  2111. sample_length=None,
  2112. ) -> List[torch.LongTensor]:
  2113. """
  2114. Core sampling function used to generate music tokens. Iterates over the provided list of levels, while saving
  2115. the generated raw audio at each step.
  2116. Args:
  2117. music_tokens (`List[torch.LongTensor]`):
  2118. A sequence of music tokens of length `self.levels` which will be used as context to continue the
  2119. sampling process. Should have `self.levels` tensors, each corresponding to the generation at a certain
  2120. level.
  2121. labels (`List[torch.LongTensor]`):
  2122. List of length `n_sample`, and shape `(self.levels, 4 + self.config.max_nb_genre +
  2123. lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens
  2124. which are used to condition the generation.
  2125. sample_levels (`List[int]`):
  2126. List of the desired levels at which the sampling will be done. A level is equivalent to the index of
  2127. the prior in the list of priors
  2128. metas (`List[Any]`, *optional*):
  2129. Metadatas used to generate the `labels`
  2130. chunk_size (`int`, *optional*, defaults to 32):
2131. Size of a chunk of audio, used to fill up the memory in chunks to prevent OOM errors. Bigger chunks
2132. mean faster memory filling but more memory consumption.
  2133. sampling_temperature (`float`, *optional*, defaults to 0.98):
  2134. Temperature used to ajust the randomness of the sampling.
  2135. lower_batch_size (`int`, *optional*, defaults to 16):
  2136. Maximum batch size for the lower level priors
  2137. max_batch_size (`int`, *optional*, defaults to 16):
  2138. Maximum batch size for the top level priors
  2139. sample_length_in_seconds (`int`, *optional*, defaults to 24):
  2140. Desired length of the generation in seconds
  2141. compute_alignments (`bool`, *optional*, defaults to `False`):
  2142. Whether or not to compute the alignment between the lyrics and the audio using the top_prior
  2143. sample_tokens (`int`, *optional*):
  2144. Precise number of tokens that should be sampled at each level. This is mostly useful for running dummy
  2145. experiments
  2146. offset (`int`, *optional*, defaults to 0):
  2147. Audio offset used as conditioning, corresponds to the starting sample in the music. If the offset is
  2148. greater than 0, the lyrics will be shifted take that intoaccount
  2149. save_results (`bool`, *optional*, defaults to `True`):
  2150. Whether or not to save the intermediate results. If `True`, will generate a folder named with the start
  2151. time.
  2152. sample_length (`int`, *optional*):
  2153. Desired length of the generation in samples.
  2154. Returns: torch.Tensor
  2155. Example:
  2156. ```python
  2157. >>> from transformers import AutoTokenizer, JukeboxModel, set_seed
  2158. >>> import torch
  2159. >>> metas = dict(artist="Zac Brown Band", genres="Country", lyrics="I met a traveller from an antique land")
  2160. >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
  2161. >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
  2162. >>> labels = tokenizer(**metas)["input_ids"]
  2163. >>> set_seed(0)
  2164. >>> zs = [torch.zeros(1, 0, dtype=torch.long) for _ in range(3)]
  2165. >>> zs = model._sample(zs, labels, [0], sample_length=40 * model.priors[0].raw_to_tokens, save_results=False)
  2166. >>> zs[0]
  2167. tensor([[1853, 1369, 1150, 1869, 1379, 1789, 519, 710, 1306, 1100, 1229, 519,
  2168. 353, 1306, 1379, 1053, 519, 653, 1631, 1467, 1229, 1229, 10, 1647,
  2169. 1254, 1229, 1306, 1528, 1789, 216, 1631, 1434, 653, 475, 1150, 1528,
  2170. 1804, 541, 1804, 1434]])
  2171. ```
  2172. """
        top_prior = self.priors[0]
        if sample_length is not None:
            total_length = sample_length
        else:
            total_length = (
                int(sample_length_in_seconds * self.config.sampling_rate) // top_prior.raw_to_tokens
            ) * top_prior.raw_to_tokens
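        # `total_length` is rounded down to a multiple of the top-level prior's `raw_to_tokens`, so the requested
        # duration corresponds to a whole number of top-level music tokens.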

        if sample_levels is None:
            sample_levels = range(len(self.priors))

        # total length of the signal, might be a bit different from the actual generated length
        self.total_length = total_length
        for level in sample_levels:
            sampling_kwargs = {
                "temp": 0.99 if level == len(self.priors) - 1 else sampling_temperature,
                "chunk_size": chunk_size,
                "sample_tokens": sample_tokens,
            }
            # Set correct total_length, hop_length, labels and sampling_kwargs for level
            total_token_to_sample = total_length // self.priors[level].raw_to_tokens
            hop_length = int(self.config.hop_fraction[level] * self.priors[level].n_ctx)
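            # The hop is a fraction of this prior's context window, so consecutive sampling windows at this level
            # overlap by `n_ctx - hop_length` tokens (see `sample_level` above).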
            # Use the larger `max_batch_size` only for the top-level prior (index 0); the upsampling priors fall back
            # to `lower_batch_size`.
            max_batch_size = lower_batch_size if level != 0 else max_batch_size
            music_tokens = self.sample_level(
                music_tokens,
                labels[level],
                offset,
                sampling_kwargs,
                level,
                total_token_to_sample,
                hop_length,
                max_batch_size,
            )

            if save_results:
                self.vqvae.to(music_tokens[level].device)
                # Decode sample
                with torch.no_grad():
                    start_level = len(self.priors) - level - 1  # vqvae levels are reversed
                    raw_audio = self.vqvae.decode(
                        music_tokens[: level + 1], start_level=start_level, bs_chunks=music_tokens[level].shape[0]
                    )
                logdir = f"jukebox/level_{level}"
                if not os.path.exists(logdir):
                    os.makedirs(logdir)
                save_temp_audio(logdir, level, metas=metas, aud=raw_audio.float())
                if compute_alignments and self.priors[0] is not None and self.priors[0].nb_relevant_lyric_tokens > 0:
                    with torch.no_grad():
                        alignments = get_alignment(music_tokens, labels[0], self.priors[0], self.config)
                    torch.save({"alignments": alignments}, f"{logdir}/lyric_alignments.pt")

        return music_tokens

    @add_start_docstrings(
        """
        Generates music tokens based on the provided `labels`. Will start at the desired prior level and automatically
        upsample the sequence. If you want to create the audio, you should call `model.decode(tokens)`, which will use
        the VQ-VAE decoder to convert the music tokens to raw audio.

        Args:
            labels (`List[torch.LongTensor]`):
                List of length `n_sample`, and shape `(self.levels, 4 + self.config.max_nb_genre +
                lyric_sequence_length)` metadata such as `artist_id`, `genre_id` and the full list of lyric tokens
                which are used to condition the generation.
            n_samples (`int`, *optional*, defaults to 1):
                Number of samples to be generated in parallel.
        """,
    )
    def ancestral_sample(self, labels, n_samples=1, **sampling_kwargs) -> List[torch.LongTensor]:
  2236. """
  2237. Example:
  2238. ```python
  2239. >>> from transformers import AutoTokenizer, JukeboxModel, set_seed
  2240. >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
  2241. >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
  2242. >>> lyrics = "Hey, are you awake? Can you talk to me?"
  2243. >>> artist = "Zac Brown Band"
  2244. >>> genre = "Country"
  2245. >>> metas = tokenizer(artist=artist, genres=genre, lyrics=lyrics)
  2246. >>> set_seed(0)
  2247. >>> music_tokens = model.ancestral_sample(metas.input_ids, sample_length=400)
  2248. >>> with torch.no_grad():
  2249. ... model.decode(music_tokens)[:, :10].squeeze(-1)
  2250. tensor([[-0.0219, -0.0679, -0.1050, -0.1203, -0.1271, -0.0936, -0.0396, -0.0405,
  2251. -0.0818, -0.0697]])
  2252. ```
  2253. """
        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
        music_tokens = [
            torch.zeros(n_samples, 0, dtype=torch.long, device=labels[0].device) for _ in range(len(self.priors))
        ]
        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
        return music_tokens

    @add_start_docstrings(
        """Generates a continuation of the previously generated tokens.

        Args:
            music_tokens (`List[torch.LongTensor]` of length `self.levels`):
                A sequence of music tokens which will be used as context to continue the sampling process. Should have
                `self.levels` tensors, each corresponding to the generation at a certain level.
        """,
        JUKEBOX_SAMPLING_INPUT_DOCSTRING,
    )
    def continue_sample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]:
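        """
        Example (an illustrative sketch rather than a verified output; the checkpoint, metadata, `sample_levels` and
        `sample_length` values are assumptions):

        ```python
        >>> from transformers import AutoTokenizer, JukeboxModel, set_seed

        >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
        >>> metas = tokenizer(artist="Zac Brown Band", genres="Country", lyrics="Hey, are you awake?")
        >>> set_seed(0)

        >>> # Generate a short top-level sample, then ask for a longer continuation of the same tokens
        >>> music_tokens = model.ancestral_sample(metas.input_ids, sample_levels=[0], sample_length=400)
        >>> music_tokens = model.continue_sample(music_tokens, metas.input_ids, sample_levels=[0], sample_length=800)
        ```
        """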
        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
        return music_tokens

    @add_start_docstrings(
        """Upsamples a sequence of music tokens using the prior at level `level`.

        Args:
            music_tokens (`List[torch.LongTensor]` of length `self.levels`):
                A sequence of music tokens which will be used as context to continue the sampling process. Should have
                `self.levels` tensors, each corresponding to the generation at a certain level.
        """,
        JUKEBOX_SAMPLING_INPUT_DOCSTRING,
    )
    def upsample(self, music_tokens, labels, **sampling_kwargs) -> List[torch.LongTensor]:
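        """
        Example (an illustrative sketch rather than a verified output; the checkpoint, metadata, `sample_levels` and
        `sample_length` values are assumptions, with levels 1 and 2 taken to be the upsampling priors of this
        checkpoint):

        ```python
        >>> from transformers import AutoTokenizer, JukeboxModel, set_seed

        >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
        >>> metas = tokenizer(artist="Zac Brown Band", genres="Country", lyrics="Hey, are you awake?")
        >>> set_seed(0)

        >>> # Sample the top level first, then refine the result with the upsampling priors
        >>> music_tokens = model.ancestral_sample(metas.input_ids, sample_levels=[0], sample_length=400)
        >>> music_tokens = model.upsample(music_tokens, metas.input_ids, sample_levels=[1, 2], sample_length=400)
        ```
        """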
        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors) - 1)))
        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
        return music_tokens

    @add_start_docstrings(
        """Generates raw audio conditioned on the provided `raw_audio`, which is used as conditioning at each of the
        generation levels. The audio is encoded to music tokens using the 3 levels of the VQ-VAE. These tokens are
        used as conditioning for each level, which means that no ancestral sampling is required.

        Args:
            raw_audio (`List[torch.Tensor]` of length `n_samples`):
                A list of raw audio that will be used as conditioning information for each sample that will be
                generated.
        """,
        JUKEBOX_SAMPLING_INPUT_DOCSTRING,
    )
    def primed_sample(self, raw_audio, labels, **sampling_kwargs) -> List[torch.LongTensor]:
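        """
        Example (an illustrative sketch rather than a verified output; the checkpoint, metadata, the random audio
        prompt and `sample_length` are assumptions):

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, JukeboxModel, set_seed

        >>> model = JukeboxModel.from_pretrained("openai/jukebox-1b-lyrics", min_duration=0).eval()
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/jukebox-1b-lyrics")
        >>> metas = tokenizer(artist="Zac Brown Band", genres="Country", lyrics="Hey, are you awake?")
        >>> set_seed(0)

        >>> # A short mono waveform used as the audio prompt (random noise here); shape assumed to be
        >>> # (n_samples, audio_length, 1)
        >>> raw_audio = torch.randn(1, 256, 1)
        >>> music_tokens = model.primed_sample(raw_audio, metas.input_ids, sample_levels=[0], sample_length=400)
        ```
        """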
        sample_levels = sampling_kwargs.pop("sample_levels", list(range(len(self.priors))))
        self.vqvae.to(raw_audio.device).float()
        with torch.no_grad():
            music_tokens = self.vqvae.encode(
                raw_audio, start_level=0, end_level=len(self.priors), bs_chunks=raw_audio.shape[0]
            )
        music_tokens = self._sample(music_tokens, labels, sample_levels, **sampling_kwargs)
        return music_tokens