modeling_moshi.py

# coding=utf-8
# Copyright 2024 Kyutai and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Moshi model."""

import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ...generation import (
    GenerationConfig,
    GenerationMixin,
)
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    ModelOutput,
    Seq2SeqLMOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_torchdynamo_compiling,
    logging,
    replace_return_docstrings,
)
from ..auto.modeling_auto import AutoModel
from .configuration_moshi import MoshiConfig, MoshiDepthConfig


if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "MoshiConfig"
_CHECKPOINT_FOR_DOC = "kmhf/hf-moshiko"

@dataclass
class MoshiConditionalGenerationGenerateOutput(ModelOutput):
    """
    Outputs of [`MoshiForConditionalGeneration.generate`].

    Args:
        audio_sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, 1, sequence_length)`, *optional*):
            The generated audio waveforms.
        sequences (`torch.LongTensor` of shape `(batch_size*num_return_sequences, sequence_length)`):
            The generated text sequences. The second dimension (sequence_length) is either equal to `max_length` or
            shorter if all batches finished early due to the `eos_token_id`.
        sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True`):
            Final beam scores of the generated `sequences`.
        scores (`tuple(torch.FloatTensor)`, *optional*, returned when `output_scores=True`):
            Beam transition scores for each vocabulary token at each generation step. Beam transition scores consist
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
            with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
        logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True`):
            Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
            `(batch_size*num_return_sequences, sequence_length)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams*num_return_sequences, generated_length, hidden_size)`.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True`):
            Returns the model cache, used to speed up decoding. Different models have a different cache format; check
            the model's documentation. Usually, a [`~cache_utils.Cache`] instance.
        audio_codes (`torch.LongTensor` of shape `(batch_size*num_return_sequences, num_codebooks, sequence_length)`, *optional*):
            The generated audio codes. Returned if `return_audio_codes=True`. Intermediate audio "tokens" which are
            transformed into `audio_sequences` once passed through the audio decoder.
    """

    audio_sequences: Optional[torch.Tensor] = None
    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
    audio_codes: Optional[torch.LongTensor] = None

@dataclass
class MoshiCausalLMOutputWithPast(ModelOutput):
    """
    `MoshiForCausalLM` outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

@dataclass
class MoshiConditionalGenerationOutputWithPast(ModelOutput):
    """
    `MoshiForConditionalGeneration` outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `text_labels` is provided):
            Text language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the text language modeling head (scores for each vocabulary token before SoftMax).
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`.
            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        depth_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `audio_labels` is provided):
            Audio language modeling loss (for next-token prediction).
        audio_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the audio language modeling heads.
        depth_past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Past key-values of the depth decoder.
        depth_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Hidden states of the depth decoder.
        depth_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            The depth decoder's attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    depth_loss: Optional[torch.FloatTensor] = None
    audio_logits: torch.FloatTensor = None
    depth_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    depth_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    depth_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

@dataclass
class MoshiUnconditionalInput(ModelOutput):
    """
    Args:
        input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            The sequence used as a text prompt for the generation.
        user_audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*):
            The audio codes used as audio user prompt for the generation. Has priority over `user_input_values` and
            represents the audio "tokens" of `user_input_values` once passed through the audio encoder.
        moshi_audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*):
            The audio codes used as audio Moshi prompt for the generation. Has priority over `moshi_input_values` and
            represents the audio "tokens" of `moshi_input_values` once passed through the audio encoder.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Attention mask to avoid performing attention on padding token indices. Mask values selected in `[0,
            1]`: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
    """

    input_ids: torch.LongTensor = None
    user_audio_codes: torch.Tensor = None
    moshi_audio_codes: torch.Tensor = None
    attention_mask: torch.LongTensor = None

# Copied from transformers.models.gemma.modeling_gemma.GemmaRMSNorm with Gemma->Moshi
class MoshiRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # Ignore copy

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    # Ignore copy
    def forward(self, x):
        output = self._norm(x.float())
        output = output * self.weight.float()
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


ALL_LAYERNORM_LAYERS.append(MoshiRMSNorm)
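

# Editor's note: the small helper below is an illustrative sketch added for this edit, not part of the original
# model code. It checks, on a toy tensor, that `MoshiRMSNorm` matches the formula its `_norm` method implements:
# x * rsqrt(mean(x**2) + eps), scaled by the learned `weight`.
def _example_moshi_rms_norm():
    torch.manual_seed(0)
    norm = MoshiRMSNorm(dim=8, eps=1e-6)
    x = torch.randn(2, 5, 8)
    expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps) * norm.weight
    torch.testing.assert_close(norm(x), expected)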

class MoshiFlexibleLinear(nn.Module):
    def __init__(self, input_size, output_size, num_layers):
        super().__init__()
        # Stack the weights for N layers into a single tensor (num_layers, output_size, input_size)
        self.weight = nn.Parameter(torch.randn(num_layers, output_size, input_size))

    def forward(self, x, layer_idx=None):
        """
        `MoshiFlexibleLinear` creates one linear layer per codebook. There are multiple ways to use it.

        In the default case, `sequence_length=num_layers`, so each element of the sequence is matmuled against the
        weights corresponding to its index in the sequence.

        For more advanced cases, one can specify which codebook's layer(s) to use with `layer_idx`.
        If `layer_idx` is a single index, every element of the sequence is matmuled against that single codebook's layer.
        If `layer_idx` is a tensor of shape `(seq_length,)`, the i-th element of the input sequence is matmuled against
        the corresponding layer `weight[layer_idx[i]]`.

        Args:
            x (`torch.FloatTensor`): input to the layer of shape `(batch, num_layers, embed_dim)` or of shape `(batch, seq_length, embed_dim)`
            layer_idx (`torch.Tensor`, *optional*):
                Can be used to specify which codebook's layer(s) to use.
                If it's a tensor of shape `(seq_length,)`, each element of the sequence is matmuled against the
                corresponding weights; if it's a single index, that codebook's layer is used for the whole sequence.
        """
        # Use torch.index_select to pick the weights corresponding to each position
        # (codebooks, output_size, hidden_size)
        selected_weights = torch.index_select(self.weight, 0, layer_idx) if layer_idx is not None else self.weight

        # (1, codebooks, hidden_size, output_size)
        selected_weights = selected_weights.transpose(1, 2)[None, :, :, :]

        # (batch_size, codebooks, 1, hidden_size) x (1, codebooks, hidden_size, output_size)
        # -> (batch_size, codebooks, 1, output_size)
        x = torch.matmul(x[:, :, None, :], selected_weights)

        # (batch_size, codebooks, output_size)
        return x.squeeze(2)
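

# Editor's note: illustrative sketch, not part of the original file. It exercises the calling conventions
# documented in `MoshiFlexibleLinear.forward`; the `_example_*` name and toy dimensions are hypothetical.
def _example_flexible_linear():
    torch.manual_seed(0)
    num_codebooks, hidden, out = 4, 8, 6
    proj = MoshiFlexibleLinear(hidden, out, num_layers=num_codebooks)

    # Default case: sequence_length == num_layers, one weight matrix per position.
    x = torch.randn(2, num_codebooks, hidden)
    assert proj(x).shape == (2, num_codebooks, out)

    # A single codebook layer (given as a one-element index tensor) applied to every position.
    x = torch.randn(2, 5, hidden)
    assert proj(x, layer_idx=torch.tensor([1])).shape == (2, 5, out)

    # One layer index per position in the sequence.
    idx = torch.tensor([0, 1, 2, 3, 0])
    assert proj(x, layer_idx=idx).shape == (2, 5, out)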

class MoshiLinear(nn.Module):
    def __init__(self, input_dim, output_dim, num_codebooks, use_flexible_linear=False):
        super().__init__()
        self.use_flexible_linear = use_flexible_linear

        if not use_flexible_linear:
            self.linear = nn.Linear(input_dim, output_dim, bias=False)
        else:
            self.linear = MoshiFlexibleLinear(input_dim, output_dim, num_layers=num_codebooks)

    def forward(self, x, layer_idx=None):
        if self.use_flexible_linear:
            return self.linear(x, layer_idx)
        else:
            return self.linear(x)
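

# Editor's note: illustrative sketch, not part of the original file. `MoshiLinear` is a thin dispatcher that
# resolves either to a plain `nn.Linear` or to one `MoshiFlexibleLinear` weight per codebook.
def _example_moshi_linear():
    x = torch.randn(2, 3, 8)
    plain = MoshiLinear(8, 6, num_codebooks=4, use_flexible_linear=False)
    flexible = MoshiLinear(8, 6, num_codebooks=4, use_flexible_linear=True)
    assert plain(x).shape == (2, 3, 6)
    # The flexible path routes each position to the codebook layer given by `layer_idx`.
    assert flexible(x, layer_idx=torch.tensor([0, 1, 2])).shape == (2, 3, 6)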

# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Moshi
class MoshiRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward
    # TODO(joao): add me back asap :)
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
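

# Editor's note: illustrative sketch, not part of the original file. It shows how the cos/sin tensors produced by
# `MoshiRotaryEmbedding` combine with `apply_rotary_pos_emb` on toy query/key tensors; shapes and seed are arbitrary.
def _example_rotary_embedding():
    bsz, num_heads, seq_len, head_dim = 1, 2, 5, 8
    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k = torch.randn(bsz, num_heads, seq_len, head_dim)
    position_ids = torch.arange(seq_len)[None, :]

    rotary = MoshiRotaryEmbedding(head_dim)
    cos, sin = rotary(q, position_ids)  # each of shape (bsz, seq_len, head_dim)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)

    # RoPE applies a 2D rotation to pairs of channels, so the norm of each head vector is preserved.
    torch.testing.assert_close(q_rot.norm(dim=-1), q.norm(dim=-1), rtol=1e-4, atol=1e-4)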

class MoshiGatingMLP(nn.Module):
    def __init__(self, config, use_flexible_linear=False):
        super().__init__()

        self.activation_fn = ACT2FN[config.hidden_act]
        ffn_dim = config.ffn_dim
        hidden_size = config.hidden_size
        num_layers = config.num_codebooks if use_flexible_linear else 1
        if num_layers == 1:
            self.fc1 = nn.Linear(hidden_size, ffn_dim, bias=False)
            self.fc2 = nn.Linear(ffn_dim // 2, hidden_size, bias=False)
        else:
            self.fc1 = MoshiFlexibleLinear(hidden_size, ffn_dim, num_layers)
            self.fc2 = MoshiFlexibleLinear(ffn_dim // 2, hidden_size, num_layers)

    def forward(self, hidden_states: torch.Tensor, layer_idx: int = None) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states) if layer_idx is None else self.fc1(hidden_states, layer_idx)

        batch_size, sequence_length, _ = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, sequence_length, 2, -1)
        hidden_states = self.activation_fn(hidden_states[..., 0, :]) * hidden_states[..., 1, :]
        hidden_states = self.fc2(hidden_states) if layer_idx is None else self.fc2(hidden_states, layer_idx)
        return hidden_states
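

# Editor's note: illustrative sketch, not part of the original file. The SimpleNamespace below is a hypothetical
# stand-in exposing only the attributes `MoshiGatingMLP.__init__` reads (it is not a real `MoshiConfig`). The demo
# shows the GLU-style gating: fc1 projects to ffn_dim, the result is split into (gate, value) halves, and fc2
# projects the gated half (ffn_dim // 2) back to hidden_size.
def _example_gating_mlp():
    from types import SimpleNamespace

    config = SimpleNamespace(hidden_act="gelu", ffn_dim=16, hidden_size=8, num_codebooks=4)
    mlp = MoshiGatingMLP(config)
    hidden_states = torch.randn(2, 5, config.hidden_size)
    assert mlp(hidden_states).shape == (2, 5, config.hidden_size)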

# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
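

# Editor's note: illustrative sketch, not part of the original file. `repeat_kv` expands grouped key/value heads so
# they line up with the full set of query heads before the attention matmul.
def _example_repeat_kv():
    kv = torch.randn(2, 2, 5, 8)           # (batch, num_key_value_heads, seq_len, head_dim)
    expanded = repeat_kv(kv, n_rep=4)
    assert expanded.shape == (2, 8, 5, 8)   # num_key_value_heads * n_rep heads after expansion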

class MoshiAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: MoshiConfig, layer_idx: Optional[int] = None, use_flexible_linear=False, use_rope=True):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.is_causal = True
        self.scaling = 1 / math.sqrt(self.head_dim)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = MoshiLinear(
            self.hidden_size, self.num_heads * self.head_dim, config.num_codebooks, use_flexible_linear
        )
        self.k_proj = MoshiLinear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear
        )
        self.v_proj = MoshiLinear(
            self.hidden_size, self.num_key_value_heads * self.head_dim, config.num_codebooks, use_flexible_linear
        )
        self.o_proj = MoshiLinear(
            self.num_heads * self.head_dim, self.hidden_size, config.num_codebooks, use_flexible_linear
        )

        # rotary embeddings are not used in the depth decoder
        self.rotary_emb = None
        if use_rope:
            self.rope_theta = config.rope_theta
            self.rotary_emb = MoshiRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )

    # Copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states, cache_position)  # Ignore copy
        key_states = self.k_proj(hidden_states, cache_position)  # Ignore copy
        value_states = self.v_proj(hidden_states, cache_position)  # Ignore copy

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.rotary_emb is not None:  # Ignore copy
            cos, sin = self.rotary_emb(value_states, position_ids)  # Ignore copy
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)  # Ignore copy

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = (
                {"sin": sin, "cos": cos, "cache_position": cache_position}
                if self.rotary_emb is not None
                else {"cache_position": cache_position}
            )  # Ignore copy
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output, cache_position)  # Ignore copy

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Moshi
class MoshiFlashAttention2(MoshiAttention):
    """
    Moshi flash attention module. This module inherits from `MoshiAttention` as the weights of the module stay
    untouched. The only required change is on the forward pass, where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a
        # wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2`; "
                "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
            )

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states, cache_position)  # Ignore copy
        key_states = self.k_proj(hidden_states, cache_position)  # Ignore copy
        value_states = self.v_proj(hidden_states, cache_position)  # Ignore copy

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.rotary_emb is not None:  # Ignore copy
            cos, sin = self.rotary_emb(value_states, position_ids)  # Ignore copy
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)  # Ignore copy

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = (
                {"sin": sin, "cos": cos, "cache_position": cache_position}
                if self.rotary_emb is not None
                else {"cache_position": cache_position}
            )  # Ignore copy
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transposes are quite inefficient, but Flash Attention requires the layout
        # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
        # to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast in float32. Hence, we need
        # to cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference, so it is recommended to not cast the LayerNorms
        # in fp32. (MoshiRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to be silently cast in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output, cache_position)  # Ignore copy

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

# Copied from transformers.models.gemma.modeling_gemma.GemmaSdpaAttention with Gemma->Moshi
class MoshiSdpaAttention(MoshiAttention):
    """
    Moshi attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `MoshiAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
    the SDPA API.
    """

    # Adapted from MoshiAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "MoshiModel is using MoshiSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states, cache_position)  # Ignore copy
        key_states = self.k_proj(hidden_states, cache_position)  # Ignore copy
        value_states = self.v_proj(hidden_states, cache_position)  # Ignore copy

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.rotary_emb is not None:  # Ignore copy
            cos, sin = self.rotary_emb(value_states, position_ids)  # Ignore copy
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)  # Ignore copy

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = (
                {"sin": sin, "cos": cos, "cache_position": cache_position}
                if self.rotary_emb is not None
                else {"cache_position": cache_position}
            )  # Ignore copy
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output, cache_position)  # Ignore copy

        return attn_output, None, past_key_value

MOSHI_ATTENTION_CLASSES = {
    "eager": MoshiAttention,
    "flash_attention_2": MoshiFlashAttention2,
    "sdpa": MoshiSdpaAttention,
}
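

# Editor's note: illustrative sketch, not part of the original file. It shows how the decoder layer below resolves
# its attention class from `MOSHI_ATTENTION_CLASSES` via `config._attn_implementation`. The SimpleNamespace is a
# hypothetical stand-in exposing only the attributes `MoshiAttention.__init__` reads, not a real `MoshiConfig`; no
# attention mask is passed, so the toy call only demonstrates dispatch and output shapes, not causal masking.
def _example_attention_dispatch():
    from types import SimpleNamespace

    config = SimpleNamespace(
        _attn_implementation="eager",
        attention_dropout=0.0,
        hidden_size=16,
        num_attention_heads=4,
        head_dim=4,
        num_key_value_heads=4,
        max_position_embeddings=32,
        rope_theta=10000.0,
        num_codebooks=8,
    )
    attention_cls = MOSHI_ATTENTION_CLASSES[config._attn_implementation]
    self_attn = attention_cls(config=config, layer_idx=0)

    hidden_states = torch.randn(1, 6, config.hidden_size)
    position_ids = torch.arange(6)[None, :]
    attn_output, _, _ = self_attn(hidden_states, position_ids=position_ids)
    assert attn_output.shape == (1, 6, config.hidden_size)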

class MoshiDecoderLayer(nn.Module):
    def __init__(self, config: MoshiConfig, layer_idx: int, use_flexible_linear: bool, use_rope=True):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.use_flexible_linear = use_flexible_linear

        self.self_attn = MOSHI_ATTENTION_CLASSES[config._attn_implementation](
            config=config, layer_idx=layer_idx, use_flexible_linear=use_flexible_linear, use_rope=use_rope
        )

        self.mlp = MoshiGatingMLP(config, use_flexible_linear)
        self.input_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MoshiRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.sliding_window = config.sliding_window

        self._attn_implementation = config._attn_implementation

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
                into the model
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = (
            self.mlp(hidden_states) if not self.use_flexible_linear else self.mlp(hidden_states, cache_position)
        )
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

class MoshiPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MoshiConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["MoshiDecoderLayer", "MimiTransformerLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    main_input_name = "input_ids"

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            # Note: nn.Conv1d is already matched by the first branch above, so this kaiming
            # initialization is effectively unreachable as written.
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

MOSHI_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`MoshiConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
  732. MOSHI_INPUTS_DOCSTRING = r"""
  733. Args:
  734. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  735. Indices of input sequence text tokens in the vocabulary. Padding will be ignored by default should you provide
  736. it.
  737. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  738. [`PreTrainedTokenizer.__call__`] for details.
  739. [What are input IDs?](../glossary#input-ids)
  740. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  741. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  742. - 1 for tokens that are **not masked**,
  743. - 0 for tokens that are **masked**.
  744. [What are attention masks?](../glossary#attention-mask)
  745. user_input_values (`torch.Tensor `of shape `(batch_size, 1, audio_sequence_length), *optional*):
  746. The audio waveforms used as audio user prompt for the generation.
  747. user_audio_codes (`torch.Tensor `of shape `(batch_size, num_codebooks, sequence_length), *optional*):
  748. The audio codes used as audio user prompt for the generation. Has priority over `user_input_values` and represents the audio "tokens" of `user_input_values` once passed through the audio encoder.
  749. moshi_input_values (`torch.Tensor `of shape `(batch_size, 1, audio_sequence_length), *optional*):
  750. The audio waveforms used as audio Moshi prompt for the generation.
  751. moshi_audio_codes (`torch.Tensor `of shape `(batch_size, num_codebooks, sequence_length), *optional*):
  752. The audio codes used as audio Moshi prompt for the generation. Has priority over `moshi_input_values` and represents the audio "tokens" of `moshi_input_values` once passed through the audio encoder.
  753. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  754. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
  755. representation. If `past_key_values` is used, optionally only the last `inputs_embeds` have to be
  756. input (see `past_key_values`). This is useful if you want more control over how to convert
  757. `input_ids` indices into associated vectors than the model's internal embedding lookup matrix.
  758. If `input_ids` and `inputs_embeds` are both unset, `inputs_embeds` takes the value
  759. of `inputs_embeds`.
  760. past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
  761. Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  762. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  763. Two formats are allowed:
  764. - a [`~cache_utils.Cache`] instance;
  765. - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  766. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
  767. cache format.
  768. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
  769. legacy cache format will be returned.
  770. text_labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  771. Labels for text language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
772. `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
773. are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
774. audio_labels (`torch.LongTensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*):
775. Labels for audio language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
776. `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.audio_vocab_size]`. All labels set to `-100`
777. are ignored (masked); the loss is only computed for labels in `[0, ..., config.audio_vocab_size]`.
  778. use_cache (`bool`, *optional*):
  779. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  780. `past_key_values`).
  781. output_attentions (`bool`, *optional*):
  782. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  783. tensors for more detail.
  784. output_hidden_states (`bool`, *optional*):
  785. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  786. more detail.
  787. return_dict (`bool`, *optional*):
  788. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  789. """
  790. MOSHI_DECODER_INPUTS_DOCSTRING = r"""
  791. Args:
  792. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  793. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  794. it.
  795. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  796. [`PreTrainedTokenizer.__call__`] for details.
  797. [What are input IDs?](../glossary#input-ids)
  798. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  799. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  800. - 1 for tokens that are **not masked**,
  801. - 0 for tokens that are **masked**.
  802. [What are attention masks?](../glossary#attention-mask)
  803. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  804. [`PreTrainedTokenizer.__call__`] for details.
  805. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
  806. `past_key_values`).
  807. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
  808. and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
  809. information on the default strategy.
  810. - 1 indicates the head is **not masked**,
  811. - 0 indicates the head is **masked**.
  812. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  813. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  814. config.n_positions - 1]`.
  815. [What are position IDs?](../glossary#position-ids)
  816. past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
  817. Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  818. blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  819. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  820. Two formats are allowed:
  821. - a [`~cache_utils.Cache`] instance, see our
  822. [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
  823. - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  824. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
  825. cache format.
  826. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
  827. legacy cache format will be returned.
  828. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
  829. have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
  830. of shape `(batch_size, sequence_length)`.
  831. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  832. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  833. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  834. model's internal embedding lookup matrix.
  835. use_cache (`bool`, *optional*):
  836. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  837. `past_key_values`).
  838. output_attentions (`bool`, *optional*):
  839. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  840. tensors for more detail.
  841. output_hidden_states (`bool`, *optional*):
  842. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  843. more detail.
  844. return_dict (`bool`, *optional*):
  845. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  846. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  847. Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
  848. this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
  849. the complete sequence length.
  850. """
  851. class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin):
  852. """
  853. Transformer depth decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoshiTransformerLayer`]
  854. Args:
  855. config: MoshiConfig
  856. """
  857. config_class = MoshiDepthConfig
  858. def __init__(self, config: MoshiDepthConfig):
  859. super().__init__(config)
  860. self.text_embed_tokens = nn.Embedding(config.vocab_size + 1, config.hidden_size)
  861. # the last codebook is never used as input
  862. self.embed_tokens = nn.ModuleList(
  863. [nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(config.num_codebooks - 1)]
  864. )
  865. self.input_projections = MoshiFlexibleLinear(config.input_size, config.hidden_size, config.num_codebooks)
  866. self.layers = nn.ModuleList(
  867. [
  868. MoshiDecoderLayer(config, layer_idx, use_flexible_linear=True, use_rope=False)
  869. for layer_idx in range(config.num_hidden_layers)
  870. ]
  871. )
  872. self.lm_heads = MoshiFlexibleLinear(config.hidden_size, config.audio_vocab_size, config.num_codebooks)
  873. self._attn_implementation = config._attn_implementation
  874. self.gradient_checkpointing = False
  875. self.config = config
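# Layout note (descriptive, derived from the forward pass below): at depth step 0 the decoder
# consumes the text token through `text_embed_tokens`; at step k > 0 it consumes audio codebook
# k - 1 through `embed_tokens[k - 1]` (the last codebook is never used as input). At every step
# the main decoder's `last_hidden_state` is projected by `input_projections` and added to the
# token embedding, and the per-position heads in `lm_heads` produce logits over `audio_vocab_size`.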
  876. def forward(
  877. self,
  878. input_ids: Optional[torch.LongTensor] = None,
879. last_hidden_state: torch.FloatTensor = None,
  880. attention_mask: Optional[torch.BoolTensor] = None,
  881. past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
  882. inputs_embeds: Optional[torch.FloatTensor] = None,
  883. use_cache: Optional[bool] = None,
  884. output_attentions: Optional[bool] = None,
  885. output_hidden_states: Optional[bool] = None,
  886. return_dict: Optional[bool] = None,
  887. position_ids: Optional[torch.LongTensor] = None,
  888. labels: Optional[torch.LongTensor] = None,
  889. cache_position: Optional[torch.LongTensor] = None,
  890. ) -> Union[Tuple, BaseModelOutputWithPast]:
  891. """
  892. Args:
  893. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
894. Indices of input sequence tokens. The first element of the sequence must be the text token associated with the audio codebooks.
895. The remaining elements must be the flattened audio codebook tokens. The `cache_position` argument can be used to indicate which index is associated with each token.
  896. last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  897. Sequence of hidden-states at the output of the last layer of the main decoder. Used to contextualize `input_ids`
  898. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  899. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  900. - 1 for tokens that are **not masked**,
  901. - 0 for tokens that are **masked**.
  902. [What are attention masks?](../glossary#attention-mask)
  903. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  904. [`PreTrainedTokenizer.__call__`] for details.
  905. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
  906. `past_key_values`).
  907. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
  908. and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
  909. information on the default strategy.
  910. - 1 indicates the head is **not masked**,
  911. - 0 indicates the head is **masked**.
  912. past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
  913. Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  914. blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  915. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  916. Two formats are allowed:
  917. - a [`~cache_utils.Cache`] instance;
  918. - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  919. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
  920. cache format.
  921. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
  922. legacy cache format will be returned.
  923. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
  924. have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
  925. of shape `(batch_size, sequence_length)`.
  926. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  927. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  928. is useful if you want more control over how to convert the inputs into associated vectors than the
  929. model's internal embedding lookup matrix.
  930. use_cache (`bool`, *optional*):
  931. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  932. `past_key_values`).
  933. output_attentions (`bool`, *optional*):
  934. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  935. tensors for more detail.
  936. output_hidden_states (`bool`, *optional*):
  937. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  938. more detail.
  939. return_dict (`bool`, *optional*):
  940. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  941. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  942. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  943. config.n_positions - 1]`.
  944. [What are position IDs?](../glossary#position-ids)
  945. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  946. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  947. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  948. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  949. cache_position (`torch.Tensor`):
  950. Indices depicting the position of the input sequence tokens in the sequence.
  951. """
  952. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  953. output_hidden_states = (
  954. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  955. )
  956. use_cache = use_cache if use_cache is not None else self.config.use_cache
  957. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  958. if self.gradient_checkpointing and self.training and use_cache:
  959. logger.warning_once(
  960. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  961. )
  962. use_cache = False
  963. if use_cache and past_key_values is None and not self.training:
  964. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  965. past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
  966. if cache_position is None:
  967. cache_position = torch.arange(
  968. past_seen_tokens, past_seen_tokens + input_ids.shape[1], device=input_ids.device
  969. )
  970. if position_ids is None:
  971. position_ids = cache_position.unsqueeze(0)
  972. # If inputs_embeds is provided, it has the priority over input_ids, which won't be used
  973. if inputs_embeds is None:
  974. inputs_embeds = []
  975. for position_idx in cache_position:
  976. position_idx = position_idx.item()
  977. if position_idx == 0:
  978. inputs_embeds.append(self.text_embed_tokens(input_ids[:, [position_idx]]))
  979. else:
  980. inputs_embeds.append(
  981. self.embed_tokens[(position_idx - 1)](input_ids[:, [position_idx - past_seen_tokens]])
  982. )
  983. inputs_embeds = torch.cat(inputs_embeds, dim=1)
  984. inputs_embeds += self.input_projections(last_hidden_state, cache_position)
  985. causal_mask = None
  986. if attention_mask is not None:
  987. causal_mask = self._update_causal_mask(
  988. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  989. )
  990. # decoder layers
  991. all_hidden_states = () if output_hidden_states else None
  992. all_self_attns = () if output_attentions else None
  993. next_decoder_cache = None
  994. hidden_states = inputs_embeds
  995. for decoder_layer in self.layers:
  996. if output_hidden_states:
  997. all_hidden_states += (hidden_states,)
  998. if self.gradient_checkpointing and self.training:
  999. layer_outputs = self._gradient_checkpointing_func(
  1000. decoder_layer.__call__,
  1001. hidden_states,
  1002. causal_mask,
  1003. position_ids,
  1004. past_key_values,
  1005. output_attentions,
  1006. use_cache,
  1007. cache_position,
  1008. )
  1009. else:
  1010. layer_outputs = decoder_layer(
  1011. hidden_states,
  1012. attention_mask=causal_mask,
  1013. position_ids=position_ids,
  1014. past_key_value=past_key_values,
  1015. output_attentions=output_attentions,
  1016. use_cache=use_cache,
  1017. cache_position=cache_position,
  1018. )
  1019. hidden_states = layer_outputs[0]
  1020. if use_cache:
  1021. next_decoder_cache = layer_outputs[2 if output_attentions else 1]
  1022. if output_attentions:
  1023. all_self_attns += (layer_outputs[1],)
  1024. # add hidden states from the last decoder layer
  1025. if output_hidden_states:
  1026. all_hidden_states += (hidden_states,)
  1027. next_cache = next_decoder_cache if use_cache else None
  1028. logits = self.lm_heads(hidden_states, cache_position)
  1029. loss = None
  1030. if labels is not None:
  1031. # Upcast to float if we need to compute the loss to avoid potential precision issues
  1032. logits = logits.float()
  1033. loss_fct = CrossEntropyLoss()
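# Labels equal to `audio_vocab_size` are the BOS/pad placeholders inserted by the delay pattern
# (see the `build_delay_pattern_mask` usage in `MoshiForConditionalGeneration.forward`); they are
# mapped to -100 below so that `CrossEntropyLoss` ignores them, and logits/labels are flattened
# across positions before computing the loss.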
  1034. labels = labels.masked_fill(labels == self.config.audio_vocab_size, -100).reshape(-1)
  1035. # Enable model parallelism
  1036. labels = labels.to(logits.device)
  1037. loss = loss_fct(logits.reshape(-1, self.config.audio_vocab_size), labels)
  1038. if not return_dict:
  1039. return tuple(v for v in [loss, logits, next_cache, all_hidden_states, all_self_attns] if v is not None)
  1040. return CausalLMOutputWithPast(
  1041. loss=loss,
  1042. logits=logits,
  1043. past_key_values=next_cache,
  1044. hidden_states=all_hidden_states,
  1045. attentions=all_self_attns,
  1046. )
  1047. # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
  1048. def _update_causal_mask(
  1049. self,
  1050. attention_mask: torch.Tensor,
  1051. input_tensor: torch.Tensor,
  1052. cache_position: torch.Tensor,
  1053. past_key_values: Cache,
  1054. output_attentions: bool,
  1055. ):
  1056. if self.config._attn_implementation == "flash_attention_2":
  1057. if attention_mask is not None and 0.0 in attention_mask:
  1058. return attention_mask
  1059. return None
  1060. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  1061. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  1062. # to infer the attention mask.
  1063. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1064. using_static_cache = isinstance(past_key_values, StaticCache)
  1065. using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
  1066. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  1067. if (
  1068. self.config._attn_implementation == "sdpa"
  1069. and not (using_static_cache or using_sliding_window_cache)
  1070. and not output_attentions
  1071. ):
  1072. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  1073. attention_mask,
  1074. inputs_embeds=input_tensor,
  1075. past_key_values_length=past_seen_tokens,
  1076. sliding_window=self.config.sliding_window,
  1077. is_training=self.training,
  1078. ):
  1079. return None
  1080. dtype, device = input_tensor.dtype, input_tensor.device
  1081. min_dtype = torch.finfo(dtype).min
  1082. sequence_length = input_tensor.shape[1]
  1083. # SlidingWindowCache or StaticCache
  1084. if using_sliding_window_cache or using_static_cache:
  1085. target_length = past_key_values.get_max_cache_shape()
  1086. # DynamicCache or no cache
  1087. else:
  1088. target_length = (
  1089. attention_mask.shape[-1]
  1090. if isinstance(attention_mask, torch.Tensor)
  1091. else past_seen_tokens + sequence_length + 1
  1092. )
  1093. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  1094. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  1095. attention_mask,
  1096. sequence_length=sequence_length,
  1097. target_length=target_length,
  1098. dtype=dtype,
  1099. device=device,
  1100. cache_position=cache_position,
  1101. batch_size=input_tensor.shape[0],
  1102. config=self.config,
  1103. past_key_values=past_key_values,
  1104. )
  1105. if (
  1106. self.config._attn_implementation == "sdpa"
  1107. and attention_mask is not None
  1108. and attention_mask.device.type == "cuda"
  1109. and not output_attentions
  1110. ):
  1111. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  1112. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  1113. # Details: https://github.com/pytorch/pytorch/issues/110213
  1114. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  1115. return causal_mask
  1116. @staticmethod
  1117. # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->MoshiDepth
  1118. def _prepare_4d_causal_attention_mask_with_cache_position(
  1119. attention_mask: torch.Tensor,
  1120. sequence_length: int,
  1121. target_length: int,
  1122. dtype: torch.dtype,
  1123. device: torch.device,
  1124. cache_position: torch.Tensor,
  1125. batch_size: int,
  1126. config: MoshiDepthConfig,
  1127. past_key_values: Cache,
  1128. ):
  1129. """
  1130. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  1131. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  1132. Args:
  1133. attention_mask (`torch.Tensor`):
  1134. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
  1135. sequence_length (`int`):
  1136. The sequence length being processed.
  1137. target_length (`int`):
  1138. The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
  1139. dtype (`torch.dtype`):
  1140. The dtype to use for the 4D attention mask.
  1141. device (`torch.device`):
1142. The device to place the 4D attention mask on.
1143. cache_position (`torch.Tensor`):
1144. Indices depicting the position of the input sequence tokens in the sequence.
1145. batch_size (`int`):
1146. Batch size.
  1147. config (`MoshiDepthConfig`):
  1148. The model's configuration class
  1149. past_key_values (`Cache`):
  1150. The cache class that is being used currently to generate
  1151. """
  1152. if attention_mask is not None and attention_mask.dim() == 4:
  1153. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  1154. causal_mask = attention_mask
  1155. else:
  1156. min_dtype = torch.finfo(dtype).min
  1157. causal_mask = torch.full(
  1158. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
  1159. )
  1160. diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
  1161. if config.sliding_window is not None:
  1162. # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
1163. # the check is needed to verify if the current checkpoint was trained with sliding window or not
  1164. if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
  1165. sliding_attend_mask = torch.arange(target_length, device=device) <= (
  1166. cache_position.reshape(-1, 1) - config.sliding_window
  1167. )
  1168. diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
  1169. causal_mask *= diagonal_attend_mask
  1170. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  1171. if attention_mask is not None:
  1172. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  1173. if attention_mask.shape[-1] > target_length:
  1174. attention_mask = attention_mask[:, :target_length]
  1175. mask_length = attention_mask.shape[-1]
  1176. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
  1177. padding_mask = padding_mask == 0
  1178. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  1179. padding_mask, min_dtype
  1180. )
  1181. return causal_mask
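# Illustrative shape walkthrough for `_prepare_4d_causal_attention_mask_with_cache_position`
# (a sketch under assumed toy values, not part of the model): with `sequence_length=1`,
# `target_length=4`, `cache_position=torch.tensor([2])`, no sliding window and no 2D mask,
# `diagonal_attend_mask` is `[[False, False, False, True]]`, so the returned mask of shape
# `(batch_size, 1, 1, 4)` is 0 (attend) at key positions 0..2 and `min_dtype` (blocked) at
# key position 3, i.e. the query at position 2 cannot attend to the not-yet-filled slot 3.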
  1182. @add_start_docstrings(
  1183. "The bare Moshi Model outputting raw hidden-states without any specific head on top.",
  1184. MOSHI_START_DOCSTRING,
  1185. )
  1186. class MoshiModel(MoshiPreTrainedModel):
  1187. """
  1188. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoshiDecoderLayer`]
  1189. Args:
  1190. config: MoshiConfig
  1191. """
  1192. def __init__(self, config: MoshiConfig):
  1193. super().__init__(config)
  1194. self.padding_idx = config.pad_token_id
  1195. self.vocab_size = config.vocab_size
  1196. self.embed_tokens = nn.Embedding(config.vocab_size + 1, config.hidden_size, self.padding_idx)
  1197. self.layers = nn.ModuleList(
  1198. [
  1199. MoshiDecoderLayer(config, layer_idx, use_flexible_linear=False)
  1200. for layer_idx in range(config.num_hidden_layers)
  1201. ]
  1202. )
  1203. self.norm = MoshiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  1204. self.gradient_checkpointing = False
  1205. # Initialize weights and apply final processing
  1206. self.post_init()
  1207. def get_input_embeddings(self):
  1208. return self.embed_tokens
  1209. def set_input_embeddings(self, value):
  1210. self.embed_tokens = value
  1211. @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING)
  1212. def forward(
  1213. self,
  1214. input_ids: torch.LongTensor = None,
  1215. attention_mask: Optional[torch.Tensor] = None,
  1216. position_ids: Optional[torch.LongTensor] = None,
  1217. past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  1218. inputs_embeds: Optional[torch.FloatTensor] = None,
  1219. use_cache: Optional[bool] = None,
  1220. output_attentions: Optional[bool] = None,
  1221. output_hidden_states: Optional[bool] = None,
  1222. return_dict: Optional[bool] = None,
  1223. cache_position: Optional[torch.LongTensor] = None,
  1224. ) -> Union[Tuple, BaseModelOutputWithPast]:
  1225. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1226. output_hidden_states = (
  1227. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1228. )
  1229. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1230. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1231. if self.gradient_checkpointing and self.training and use_cache:
  1232. logger.warning_once(
  1233. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  1234. )
  1235. use_cache = False
  1236. if inputs_embeds is None:
  1237. inputs_embeds = self.embed_tokens(input_ids)
  1238. return_legacy_cache = False # noqa: F841
  1239. if (
  1240. use_cache and not isinstance(past_key_values, Cache) and not self.training
  1241. ): # kept for BC (non `Cache` `past_key_values` inputs)
  1242. return_legacy_cache = True # noqa: F841
  1243. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  1244. if cache_position is None:
  1245. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1246. cache_position = torch.arange(
  1247. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  1248. )
  1249. if position_ids is None:
  1250. position_ids = cache_position.unsqueeze(0)
  1251. causal_mask = None
  1252. if attention_mask is not None:
  1253. causal_mask = self._update_causal_mask(
  1254. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  1255. )
  1256. # embed positions
  1257. hidden_states = inputs_embeds
  1258. if (
  1259. use_cache and not isinstance(past_key_values, Cache) and not self.training
  1260. ): # kept for BC (non `Cache` `past_key_values` inputs)
  1261. return_legacy_cache = True
  1262. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  1263. logger.warning_once(
  1264. "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
  1265. "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
  1266. )
  1267. # decoder layers
  1268. all_hidden_states = () if output_hidden_states else None
  1269. all_self_attns = () if output_attentions else None
  1270. next_decoder_cache = None
  1271. for decoder_layer in self.layers:
  1272. if output_hidden_states:
  1273. all_hidden_states += (hidden_states,)
  1274. if self.gradient_checkpointing and self.training:
  1275. layer_outputs = self._gradient_checkpointing_func(
  1276. decoder_layer.__call__,
  1277. hidden_states,
  1278. causal_mask,
  1279. position_ids,
  1280. past_key_values,
  1281. output_attentions,
  1282. use_cache,
  1283. cache_position,
  1284. )
  1285. else:
  1286. layer_outputs = decoder_layer(
  1287. hidden_states,
  1288. attention_mask=causal_mask,
  1289. position_ids=position_ids,
  1290. past_key_value=past_key_values,
  1291. output_attentions=output_attentions,
  1292. use_cache=use_cache,
  1293. cache_position=cache_position,
  1294. )
  1295. hidden_states = layer_outputs[0]
  1296. if use_cache:
  1297. next_decoder_cache = layer_outputs[2 if output_attentions else 1]
  1298. if output_attentions:
  1299. all_self_attns += (layer_outputs[1],)
  1300. hidden_states = self.norm(hidden_states)
  1301. # add hidden states from the last decoder layer
  1302. if output_hidden_states:
  1303. all_hidden_states += (hidden_states,)
  1304. next_cache = next_decoder_cache if use_cache else None
  1305. if return_legacy_cache:
  1306. next_cache = next_cache.to_legacy_cache()
  1307. if not return_dict:
  1308. return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
  1309. return BaseModelOutputWithPast(
  1310. last_hidden_state=hidden_states,
  1311. past_key_values=next_cache,
  1312. hidden_states=all_hidden_states,
  1313. attentions=all_self_attns,
  1314. )
  1315. # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
  1316. def _update_causal_mask(
  1317. self,
  1318. attention_mask: torch.Tensor,
  1319. input_tensor: torch.Tensor,
  1320. cache_position: torch.Tensor,
  1321. past_key_values: Cache,
  1322. output_attentions: bool,
  1323. ):
  1324. if self.config._attn_implementation == "flash_attention_2":
  1325. if attention_mask is not None and 0.0 in attention_mask:
  1326. return attention_mask
  1327. return None
  1328. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  1329. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  1330. # to infer the attention mask.
  1331. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1332. using_static_cache = isinstance(past_key_values, StaticCache)
  1333. using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
  1334. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  1335. if (
  1336. self.config._attn_implementation == "sdpa"
  1337. and not (using_static_cache or using_sliding_window_cache)
  1338. and not output_attentions
  1339. ):
  1340. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  1341. attention_mask,
  1342. inputs_embeds=input_tensor,
  1343. past_key_values_length=past_seen_tokens,
  1344. sliding_window=self.config.sliding_window,
  1345. is_training=self.training,
  1346. ):
  1347. return None
  1348. dtype, device = input_tensor.dtype, input_tensor.device
  1349. min_dtype = torch.finfo(dtype).min
  1350. sequence_length = input_tensor.shape[1]
  1351. # SlidingWindowCache or StaticCache
  1352. if using_sliding_window_cache or using_static_cache:
  1353. target_length = past_key_values.get_max_cache_shape()
  1354. # DynamicCache or no cache
  1355. else:
  1356. target_length = (
  1357. attention_mask.shape[-1]
  1358. if isinstance(attention_mask, torch.Tensor)
  1359. else past_seen_tokens + sequence_length + 1
  1360. )
  1361. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  1362. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  1363. attention_mask,
  1364. sequence_length=sequence_length,
  1365. target_length=target_length,
  1366. dtype=dtype,
  1367. device=device,
  1368. cache_position=cache_position,
  1369. batch_size=input_tensor.shape[0],
  1370. config=self.config,
  1371. past_key_values=past_key_values,
  1372. )
  1373. if (
  1374. self.config._attn_implementation == "sdpa"
  1375. and attention_mask is not None
  1376. and attention_mask.device.type == "cuda"
  1377. and not output_attentions
  1378. ):
  1379. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  1380. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  1381. # Details: https://github.com/pytorch/pytorch/issues/110213
  1382. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  1383. return causal_mask
  1384. @staticmethod
  1385. # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Moshi
  1386. def _prepare_4d_causal_attention_mask_with_cache_position(
  1387. attention_mask: torch.Tensor,
  1388. sequence_length: int,
  1389. target_length: int,
  1390. dtype: torch.dtype,
  1391. device: torch.device,
  1392. cache_position: torch.Tensor,
  1393. batch_size: int,
  1394. config: MoshiConfig,
  1395. past_key_values: Cache,
  1396. ):
  1397. """
  1398. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  1399. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  1400. Args:
  1401. attention_mask (`torch.Tensor`):
  1402. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
  1403. sequence_length (`int`):
  1404. The sequence length being processed.
  1405. target_length (`int`):
  1406. The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
  1407. dtype (`torch.dtype`):
  1408. The dtype to use for the 4D attention mask.
  1409. device (`torch.device`):
1410. The device to place the 4D attention mask on.
1411. cache_position (`torch.Tensor`):
1412. Indices depicting the position of the input sequence tokens in the sequence.
1413. batch_size (`int`):
1414. Batch size.
  1415. config (`MoshiConfig`):
  1416. The model's configuration class
  1417. past_key_values (`Cache`):
  1418. The cache class that is being used currently to generate
  1419. """
  1420. if attention_mask is not None and attention_mask.dim() == 4:
  1421. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  1422. causal_mask = attention_mask
  1423. else:
  1424. min_dtype = torch.finfo(dtype).min
  1425. causal_mask = torch.full(
  1426. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
  1427. )
  1428. diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
  1429. if config.sliding_window is not None:
  1430. # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
1431. # the check is needed to verify if the current checkpoint was trained with sliding window or not
  1432. if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
  1433. sliding_attend_mask = torch.arange(target_length, device=device) <= (
  1434. cache_position.reshape(-1, 1) - config.sliding_window
  1435. )
  1436. diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
  1437. causal_mask *= diagonal_attend_mask
  1438. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  1439. if attention_mask is not None:
  1440. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  1441. if attention_mask.shape[-1] > target_length:
  1442. attention_mask = attention_mask[:, :target_length]
  1443. mask_length = attention_mask.shape[-1]
  1444. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
  1445. padding_mask = padding_mask == 0
  1446. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  1447. padding_mask, min_dtype
  1448. )
  1449. return causal_mask
  1450. @add_start_docstrings(
  1451. "The Moshi decoder model with a text language modelling head on top. Only usable for text.",
  1452. MOSHI_START_DOCSTRING,
  1453. )
  1454. class MoshiForCausalLM(MoshiPreTrainedModel, GenerationMixin):
  1455. _tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
  1456. # Copied from transformers.models.gemma.modeling_gemma.GemmaForCausalLM.__init__ with Gemma->Moshi
  1457. def __init__(self, config):
  1458. super().__init__(config)
  1459. self.model = MoshiModel(config)
  1460. self.vocab_size = config.vocab_size
  1461. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1462. # Initialize weights and apply final processing
  1463. self.post_init()
  1464. def get_input_embeddings(self):
  1465. return self.model.embed_tokens
  1466. def set_input_embeddings(self, value):
  1467. self.model.embed_tokens = value
  1468. def get_output_embeddings(self):
  1469. return self.lm_head
  1470. def set_output_embeddings(self, new_embeddings):
  1471. self.lm_head = new_embeddings
  1472. def set_decoder(self, decoder):
  1473. self.model = decoder
  1474. def get_decoder(self):
  1475. return self.model
  1476. @add_start_docstrings_to_model_forward(MOSHI_DECODER_INPUTS_DOCSTRING)
  1477. @replace_return_docstrings(output_type=MoshiCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  1478. def forward(
  1479. self,
  1480. input_ids: torch.LongTensor = None,
  1481. attention_mask: Optional[torch.Tensor] = None,
  1482. position_ids: Optional[torch.LongTensor] = None,
  1483. past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  1484. inputs_embeds: Optional[torch.FloatTensor] = None,
  1485. use_cache: Optional[bool] = None,
  1486. output_attentions: Optional[bool] = None,
  1487. output_hidden_states: Optional[bool] = None,
  1488. return_dict: Optional[bool] = None,
  1489. cache_position: Optional[torch.LongTensor] = None,
  1490. labels: Optional[torch.LongTensor] = None,
  1491. num_logits_to_keep: int = 0,
  1492. ) -> Union[Tuple, MoshiCausalLMOutputWithPast]:
  1493. r"""
  1494. Args:
  1495. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1496. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1497. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1498. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1499. num_logits_to_keep (`int`, *optional*):
  1500. Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
  1501. `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  1502. token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
  1503. Returns:
  1504. Example:
  1505. ```python
  1506. >>> from transformers import AutoTokenizer, MoshiForCausalLM
  1507. >>> model = MoshiForCausalLM.from_pretrained("kmhf/hf-moshiko")
  1508. >>> tokenizer = AutoTokenizer.from_pretrained("kmhf/hf-moshiko")
  1509. >>> prompt = "What is your favorite condiment?"
  1510. >>> inputs = tokenizer(prompt, return_tensors="pt")
  1511. >>> # Generate
  1512. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  1513. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1514. "What is your favorite condiment?"
  1515. ```"""
  1516. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1517. output_hidden_states = (
  1518. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1519. )
  1520. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1521. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1522. outputs = self.model(
  1523. input_ids=input_ids,
  1524. attention_mask=attention_mask,
  1525. position_ids=position_ids,
  1526. past_key_values=past_key_values,
  1527. inputs_embeds=inputs_embeds,
  1528. use_cache=use_cache,
  1529. output_attentions=output_attentions,
  1530. output_hidden_states=output_hidden_states,
  1531. return_dict=return_dict,
  1532. cache_position=cache_position,
  1533. )
  1534. hidden_states = outputs[0]
  1535. if labels is None and not is_torchdynamo_compiling():
  1536. logger.warning_once(
  1537. "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
  1538. )
  1539. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1540. logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
  1541. loss = None
  1542. if labels is not None:
  1543. # Upcast to float if we need to compute the loss to avoid potential precision issues
  1544. logits = logits.float()
  1545. # Shift so that tokens < n predict n
  1546. shift_logits = logits[..., :-1, :].contiguous()
  1547. shift_labels = labels[..., 1:].contiguous()
  1548. # Flatten the tokens
  1549. loss_fct = CrossEntropyLoss()
  1550. shift_logits = shift_logits.view(-1, self.config.vocab_size)
  1551. shift_labels = shift_labels.view(-1)
  1552. # Enable model parallelism
  1553. shift_labels = shift_labels.to(shift_logits.device)
  1554. loss = loss_fct(shift_logits, shift_labels)
  1555. if not return_dict:
  1556. output = (
  1557. logits,
  1558. hidden_states,
  1559. ) + outputs[1:]
  1560. return (loss,) + output if loss is not None else output
  1561. return MoshiCausalLMOutputWithPast(
  1562. loss=loss,
  1563. logits=logits,
  1564. last_hidden_state=hidden_states, # Ignore copy
  1565. past_key_values=outputs.past_key_values,
  1566. hidden_states=outputs.hidden_states,
  1567. attentions=outputs.attentions,
  1568. )
  1569. @add_start_docstrings(
  1570. "The original Moshi model with an audio encoder, a Moshi depth decoder and a Moshi decoder, "
  1571. "for speech-to-speech.",
  1572. MOSHI_START_DOCSTRING,
  1573. )
  1574. class MoshiForConditionalGeneration(MoshiPreTrainedModel, GenerationMixin):
  1575. _tied_weights_keys = ["decoder.model.embed_tokens.weight", "decoder.lm_head.weight"]
  1576. config_class = MoshiConfig
  1577. main_input_name = "input_ids"
  1578. supports_gradient_checkpointing = True
  1579. _supports_flash_attn_2 = True
  1580. _supports_sdpa = True
  1581. def __init__(self, config: MoshiConfig):
  1582. super().__init__(config)
  1583. # We have 2 * num_codebooks audio embedding layers because we have the user input channel and the model output channel.
  1584. self.embed_tokens = nn.ModuleList(
  1585. [nn.Embedding(config.audio_vocab_size + 1, config.hidden_size) for _ in range(2 * config.num_codebooks)]
  1586. )
  1587. self.audio_encoder = AutoModel.from_config(
  1588. config.audio_encoder_config, attn_implementation=config._attn_implementation
  1589. )
  1590. self.decoder = MoshiForCausalLM(config)
  1591. config.depth_decoder_config._attn_implementation_internal = config._attn_implementation
  1592. self.depth_decoder = MoshiDepthDecoder(config.depth_decoder_config)
  1593. self.num_codebooks = config.num_codebooks
  1594. self.post_init()
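# Embedding layout note (descriptive, derived from `forward` and `_prepare_inputs_embeds_for_generation`):
# `embed_tokens[0:num_codebooks]` embed the Moshi (model output) audio stream and
# `embed_tokens[num_codebooks:2 * num_codebooks]` embed the user audio stream, matching the
# concatenation order `torch.cat([moshi_audio_codes, user_audio_codes], dim=1)` used below.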
  1595. def get_audio_encoder(self):
  1596. return self.audio_encoder
  1597. def get_depth_decoder(self):
  1598. return self.depth_decoder
  1599. def get_decoder(self):
  1600. return self.decoder
  1601. @add_start_docstrings_to_model_forward(MOSHI_INPUTS_DOCSTRING)
  1602. @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
  1603. def forward(
  1604. self,
  1605. input_ids: Optional[torch.LongTensor] = None,
  1606. attention_mask: Optional[torch.BoolTensor] = None,
  1607. user_input_values: Optional[torch.FloatTensor] = None,
  1608. user_audio_codes: Optional[torch.Tensor] = None,
  1609. moshi_input_values: Optional[torch.FloatTensor] = None,
  1610. moshi_audio_codes: Optional[torch.Tensor] = None,
  1611. past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
  1612. inputs_embeds: Optional[torch.FloatTensor] = None,
  1613. text_labels: Optional[torch.LongTensor] = None,
  1614. audio_labels: Optional[torch.LongTensor] = None,
  1615. use_cache: Optional[bool] = None,
  1616. output_attentions: Optional[bool] = None,
  1617. output_hidden_states: Optional[bool] = None,
  1618. return_dict: Optional[bool] = None,
  1619. **kwargs,
  1620. ) -> Union[Tuple, Seq2SeqLMOutput]:
  1621. r"""
  1622. Returns:
  1623. Examples:
  1624. ```python
  1625. >>> from transformers import MoshiForConditionalGeneration
  1626. >>> import torch
  1627. >>> model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko")
1628. >>> inputs = model.get_unconditional_inputs()
1629. >>> logits = model(**inputs).logits
  1630. >>> logits.shape # (bsz, seq_len, text_vocab_size)
  1631. torch.Size([1, 1, 32000])
  1632. ```"""
  1633. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1634. kwargs_audio_encoder = {
1635. argument[len("audio_encoder_") :]: value
  1636. for argument, value in kwargs.items()
  1637. if argument.startswith("audio_encoder_")
  1638. }
  1639. kwargs_decoder = {
  1640. argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
  1641. }
  1642. kwargs_depth_decoder = {
  1643. argument[len("depth_decoder_") :]: value
  1644. for argument, value in kwargs.items()
  1645. if argument.startswith("depth_decoder_")
  1646. }
  1647. # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used
  1648. if inputs_embeds is None:
  1649. if user_input_values is not None and user_audio_codes is None:
  1650. user_audio_codes = self.audio_encoder.encode(
  1651. user_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder
  1652. )[0]
  1653. if moshi_input_values is not None and moshi_audio_codes is None:
  1654. moshi_audio_codes = self.audio_encoder.encode(
  1655. moshi_input_values, num_quantizers=self.num_codebooks, **kwargs_audio_encoder
  1656. )[0]
  1657. audio_codes = torch.cat([moshi_audio_codes, user_audio_codes], dim=1)
  1658. if input_ids is None and audio_codes is None:
  1659. raise ValueError(
  1660. "You must provide at least one of `input_ids`, `inputs_embeds`, `input_values` and `audio_codes`."
  1661. )
  1662. if input_ids is not None:
  1663. inputs_embeds = self.decoder.model.embed_tokens(input_ids)
  1664. if audio_codes is not None:
  1665. audio_inputs_embeds = sum(
  1666. [self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])]
  1667. )
  1668. inputs_embeds = (
  1669. audio_inputs_embeds
  1670. if inputs_embeds is None
  1671. else audio_inputs_embeds + inputs_embeds.to(audio_inputs_embeds.device)
  1672. )
  1673. # Decode
  1674. decoder_outputs = self.decoder(
  1675. attention_mask=attention_mask,
  1676. inputs_embeds=inputs_embeds,
  1677. output_attentions=output_attentions,
  1678. output_hidden_states=output_hidden_states,
  1679. use_cache=use_cache,
  1680. past_key_values=past_key_values,
  1681. return_dict=True,
  1682. labels=text_labels,
  1683. **kwargs_decoder,
  1684. )
  1685. decoder_last_hidden_state = decoder_outputs.last_hidden_state
  1686. depth_decoder_outputs = None
  1687. final_loss = decoder_outputs.loss
  1688. if text_labels is not None and audio_labels is not None:
  1689. # To use depth decoder forward here, we actually need oracle input ids since we're supposed to pass the true input ids
  1690. audio_labels = self.build_delay_pattern_mask(
  1691. audio_labels,
  1692. bos_token_id=self.config.audio_vocab_size,
  1693. pad_token_id=self.config.audio_vocab_size,
  1694. max_length=audio_labels.shape[-1] + 1,
  1695. )[0]
  1696. # (batch_size, sequence_length) -> (batch_size * sequence_length, 1)
  1697. text_labels = text_labels.view(-1, 1)
  1698. # (batch_size, num_codebooks, sequence_length) -> (batch_size * sequence_length, num_codebooks)
  1699. audio_labels = audio_labels.transpose(1, 2).reshape(-1, audio_labels.shape[1])
  1700. depth_input_ids = torch.cat([text_labels, audio_labels], dim=1)
  1701. # keep the last codebook out of input_ids
  1702. depth_input_ids = depth_input_ids[:, :-1]
  1703. # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim)
  1704. decoder_last_hidden_state = decoder_last_hidden_state.view(-1, 1, decoder_last_hidden_state.shape[-1])
  1705. depth_decoder_outputs = self.depth_decoder(
  1706. last_hidden_state=decoder_last_hidden_state,
  1707. input_ids=depth_input_ids,
  1708. attention_mask=attention_mask,
  1709. labels=audio_labels,
  1710. **kwargs_depth_decoder,
  1711. )
  1712. final_loss += depth_decoder_outputs.loss
  1713. if not return_dict:
  1714. outputs = decoder_outputs.to_tuple()
  1715. if depth_decoder_outputs is not None:
  1716. outputs += depth_decoder_outputs.to_tuple()
  1717. return outputs
  1718. return MoshiConditionalGenerationOutputWithPast(
1719. loss=final_loss,
  1720. logits=decoder_outputs.logits,
  1721. last_hidden_state=decoder_last_hidden_state,
  1722. past_key_values=decoder_outputs.past_key_values,
  1723. hidden_states=decoder_outputs.hidden_states,
  1724. attentions=decoder_outputs.attentions,
  1725. depth_loss=None if depth_decoder_outputs is None else depth_decoder_outputs.loss,
  1726. audio_logits=None if depth_decoder_outputs is None else depth_decoder_outputs.logits,
1727. depth_past_key_values=None if depth_decoder_outputs is None else depth_decoder_outputs.past_key_values,
1728. depth_hidden_states=None if depth_decoder_outputs is None else depth_decoder_outputs.hidden_states,
1729. depth_attentions=None if depth_decoder_outputs is None else depth_decoder_outputs.attentions,
  1730. )
  1731. def _prepare_inputs_embeds_for_generation(
  1732. self,
  1733. input_ids: Optional[torch.LongTensor] = None,
  1734. user_input_values: Optional[torch.FloatTensor] = None,
  1735. user_audio_codes: Optional[torch.Tensor] = None,
  1736. moshi_input_values: Optional[torch.FloatTensor] = None,
  1737. moshi_audio_codes: Optional[torch.Tensor] = None,
  1738. inputs_embeds: Optional[torch.FloatTensor] = None,
  1739. attention_mask: Optional[torch.Tensor] = None,
  1740. generation_config: Optional[GenerationConfig] = None,
  1741. apply_delay_pattern_mask: bool = False,
  1742. concat_unconditional_inputs: bool = False,
  1743. ):
  1744. user_delay_pattern_mask = None
  1745. moshi_delay_pattern_mask = None
  1746. if (
  1747. inputs_embeds is None
  1748. and input_ids is None
  1749. and user_input_values is None
  1750. and user_audio_codes is None
  1751. and moshi_input_values is None
  1752. and moshi_audio_codes is None
  1753. ):
  1754. raise ValueError(
  1755. "You must provide at least one of `input_ids`, `user_input_values`, `moshi_input_values`, `user_audio_codes`, `moshi_audio_codes` or `inputs_embeds`."
  1756. )
  1757. # in case inputs_embeds is passed, we might still need to create delay pattern masks
  1758. if inputs_embeds is None or apply_delay_pattern_mask:
  1759. if user_input_values is not None and user_audio_codes is None:
  1760. user_audio_codes = self.audio_encoder.encode(user_input_values, num_quantizers=self.num_codebooks)[0]
  1761. if moshi_input_values is not None and moshi_audio_codes is None:
  1762. moshi_audio_codes = self.audio_encoder.encode(moshi_input_values, num_quantizers=self.num_codebooks)[0]
  1763. if inputs_embeds is None and concat_unconditional_inputs:
  1764. unconditional_inputs = self.get_unconditional_inputs(num_samples=user_audio_codes.shape[0])
  1765. moshi_audio_codes = torch.cat([unconditional_inputs.moshi_audio_codes, moshi_audio_codes], dim=2)
  1766. user_audio_codes = torch.cat([unconditional_inputs.user_audio_codes, user_audio_codes], dim=2)
  1767. input_ids = torch.cat([unconditional_inputs.input_ids, input_ids], dim=1)
  1768. if attention_mask is not None:
  1769. attention_mask = torch.cat([unconditional_inputs.attention_mask, attention_mask], dim=1)
  1770. if inputs_embeds is None or apply_delay_pattern_mask:
  1771. if apply_delay_pattern_mask and user_audio_codes is not None:
  1772. user_audio_codes, user_delay_pattern_mask = self.build_delay_pattern_mask(
  1773. user_audio_codes,
  1774. bos_token_id=self.config.audio_vocab_size,
  1775. pad_token_id=self.config.audio_vocab_size,
  1776. max_length=generation_config.max_length,
  1777. )
  1778. if apply_delay_pattern_mask and moshi_audio_codes is not None:
  1779. moshi_audio_codes, moshi_delay_pattern_mask = self.build_delay_pattern_mask(
  1780. moshi_audio_codes,
  1781. bos_token_id=self.config.audio_vocab_size,
  1782. pad_token_id=self.config.audio_vocab_size,
  1783. max_length=generation_config.max_length,
  1784. )
  1785. # If inputs_embeds is provided, it has the priority over input_ids and audio_codes, which won't be used
  1786. if inputs_embeds is None:
  1787. audio_inputs_embeds = None
  1788. if user_audio_codes is not None and moshi_audio_codes is not None:
  1789. audio_codes = torch.cat([moshi_audio_codes, user_audio_codes], dim=1)
  1790. audio_inputs_embeds = sum(
  1791. [self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])]
  1792. )
  1793. elif moshi_audio_codes is not None:
  1794. audio_codes = moshi_audio_codes
  1795. audio_inputs_embeds = sum(
  1796. [self.embed_tokens[codebook](audio_codes[:, codebook]) for codebook in range(audio_codes.shape[1])]
  1797. )
  1798. elif user_audio_codes is not None:
  1799. audio_codes = user_audio_codes
  1800. audio_inputs_embeds = sum(
  1801. [
  1802. self.embed_tokens[codebook](audio_codes[:, codebook + self.num_codebooks])
  1803. for codebook in range(audio_codes.shape[1])
  1804. ]
  1805. )
  1806. if input_ids is not None:
  1807. inputs_embeds = self.decoder.model.embed_tokens(input_ids)
  1808. if audio_inputs_embeds is not None:
  1809. inputs_embeds = (
  1810. audio_inputs_embeds
  1811. if inputs_embeds is None
  1812. else audio_inputs_embeds + inputs_embeds.to(audio_inputs_embeds.device)
  1813. )
  1814. return (
  1815. inputs_embeds,
  1816. input_ids,
  1817. user_audio_codes,
  1818. moshi_audio_codes,
  1819. user_delay_pattern_mask,
  1820. moshi_delay_pattern_mask,
  1821. attention_mask,
  1822. )
  1823. @torch.no_grad()
  1824. def generate(
  1825. self,
  1826. input_ids: Optional[torch.LongTensor] = None,
  1827. user_input_values: Optional[torch.FloatTensor] = None,
  1828. user_audio_codes: Optional[torch.Tensor] = None,
  1829. moshi_input_values: Optional[torch.FloatTensor] = None,
  1830. moshi_audio_codes: Optional[torch.Tensor] = None,
  1831. inputs_embeds: Optional[torch.FloatTensor] = None,
  1832. return_audio_waveforms: Optional[bool] = True,
  1833. return_audio_codes: Optional[bool] = None,
  1834. concat_unconditional_inputs: Optional[bool] = True,
  1835. **kwargs,
  1836. ) -> torch.LongTensor:
  1837. """
  1838. Generates sequences of text token ids and audio tokens ids.
  1839. Parameters:
1840. input_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1841. The sequence used as a text prompt for the generation.
1842. user_input_values (`torch.Tensor` of shape `(batch_size, 1, audio_sequence_length)`, *optional*):
1843. The audio waveforms used as audio user prompt for the generation.
1844. user_audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*):
1845. The audio codes used as audio user prompt for the generation. Has priority over `user_input_values` and represents the audio "tokens" of `user_input_values` once passed through the audio encoder.
1846. moshi_input_values (`torch.Tensor` of shape `(batch_size, 1, audio_sequence_length)`, *optional*):
1847. The audio waveforms used as audio Moshi prompt for the generation.
1848. moshi_audio_codes (`torch.Tensor` of shape `(batch_size, num_codebooks, sequence_length)`, *optional*):
1849. The audio codes used as audio Moshi prompt for the generation. Has priority over `moshi_input_values` and represents the audio "tokens" of `moshi_input_values` once passed through the audio encoder.
  1850. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1851. Optionally, instead of passing `input_ids` and the audio inputs you can choose to directly pass an embedded representation. This
  1852. is useful if you want more control over how to convert the inputs into associated vectors than the
  1853. model's internal embedding lookup matrix.
  1854. return_audio_waveforms (`bool`, *optional*, defaults to `True`):
  1855. If `False`, won't generate the audio waveforms.
  1856. return_audio_codes (`bool`, *optional*):
  1857. If `True`, will also returns the generated audio codes, i.e the intermediate audio "tokens" which transforms to `audio_sequences` once passed through the audio decoder.
  1858. concat_unconditional_inputs (`bool`, *optional*, defaults to `True`):
  1859. If `False`, won't concatenate initial audio and text tokens.
  1860. kwargs (`Dict[str, Any]`, *optional*):
  1861. Remaining dictionary of keyword arguments that are passed to the `generate` method. Refers to the
  1862. original [`generate` docstrings](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate)
  1863. for more information on how to use them.
  1864. Note that keywords with a *depth_* prefix will be input for the `generate` method of the
  1865. depth decoder. Otherwise, the latter will use its default generation config.
  1866. Return:
  1867. [`MoshiConditionalGenerationGenerateOutput`]
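
        Example (a minimal unconditional-generation sketch; the checkpoint name is copied from the example in
        [`~MoshiForConditionalGeneration.get_unconditional_inputs`] and may need adapting to your setup):

        ```python
        >>> from transformers import MoshiForConditionalGeneration

        >>> model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko-pytorch-bf16")
        >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)

        >>> # kwargs with a `depth_decoder_` prefix are forwarded to the depth decoder's own `generate`
        >>> outputs = model.generate(**unconditional_inputs, max_new_tokens=64, depth_decoder_do_sample=False)
        >>> text_ids, waveform = outputs.sequences, outputs.audio_sequences
        ```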
        """
        # multiple generate -> need to create/update device map
        if hasattr(self, "hf_device_map") and not hasattr(self.depth_decoder, "hf_device_map"):
            self.depth_decoder.hf_device_map = {}
            if "" in self.hf_device_map:
                self.depth_decoder.hf_device_map = self.hf_device_map
            else:
                main_device = [d for d in self.hf_device_map.values() if d not in ["cpu", "disk"]][0]
                self.depth_decoder.hf_device_map = {
                    key[len("depth_decoder") :]: main_device if value in ["cpu", "disk"] else value
                    for key, value in self.hf_device_map.items()
                    if key.startswith("depth_decoder")
                }
            # need to remove depth_decoder from the top device_map so that we assign correctly the device for each layer idx in the cache
            self.hf_device_map = {
                key: value for key, value in self.hf_device_map.items() if not key.startswith("depth_decoder")
            }

        # retrieve depth decoder kwargs
        depth_decoder_kwargs_keys = {argument for argument in kwargs if argument.startswith("depth_decoder_")}
        kwargs_depth_decoder = {
            argument[len("depth_decoder_") :]: kwargs.pop(argument) for argument in depth_decoder_kwargs_keys
        }

        # we need to prepare the generation config, even though it'll be done again in `generate`
        generation_config, kwargs = self._prepare_generation_config(kwargs.pop("generation_config", None), **kwargs)

        input_ids, user_audio_codes, moshi_audio_codes, concat_unconditional_inputs = (
            self._check_and_maybe_initalize_inputs(
                input_ids=input_ids,
                user_input_values=user_input_values,
                user_audio_codes=user_audio_codes,
                moshi_input_values=moshi_input_values,
                moshi_audio_codes=moshi_audio_codes,
                inputs_embeds=inputs_embeds,
                concat_unconditional_inputs=concat_unconditional_inputs,
            )
        )

        inputs = inputs_embeds if input_ids is None else input_ids

        input_ids_length = inputs.shape[-1] + 1 if concat_unconditional_inputs else inputs.shape[-1]

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name="inputs_embeds" if input_ids is None else "input_ids",
            inputs_tensor=inputs,
            input_ids_length=input_ids_length,
        )

        # retrieve depth decoder generation config if it exists
        if hasattr(generation_config, "depth_decoder_config"):
            depth_decoder_generation_config = generation_config.depth_decoder_config
        else:
            # we need to control the number of tokens generated by the depth decoder
            depth_decoder_generation_config = {
                "min_length": self.num_codebooks + 1,
                "max_length": self.num_codebooks + 1,
                "cache_implementation": "sliding_window",
            }

        # update kwargs_depth_decoder: kwargs_depth_decoder have priority over depth_decoder_generation_config
        depth_decoder_generation_config.update(kwargs_depth_decoder)
        kwargs_depth_decoder = depth_decoder_generation_config

        attention_mask = kwargs.pop("attention_mask", None)
        (
            inputs_embeds,
            input_ids,
            user_audio_codes,
            moshi_audio_codes,
            user_delay_pattern_mask,
            moshi_delay_pattern_mask,
            attention_mask,
        ) = self._prepare_inputs_embeds_for_generation(
            input_ids=input_ids,
            user_input_values=user_input_values,
            user_audio_codes=user_audio_codes,
            moshi_input_values=moshi_input_values,
            moshi_audio_codes=moshi_audio_codes,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            apply_delay_pattern_mask=True,
            concat_unconditional_inputs=concat_unconditional_inputs,
        )

        # create blank user inputs - moshi needs a constant stream of user inputs
        # one frame of silence corresponds to `sampling_rate / frame_rate` waveform samples
        blank_input_values = torch.zeros(
            (inputs_embeds.shape[0], 1, int(self.config.sampling_rate / self.config.audio_encoder_config.frame_rate)),
            dtype=self.dtype,
            device=self.device,
        )
        blank_user_audio_codes = self.audio_encoder.encode(blank_input_values, num_quantizers=self.num_codebooks)[0]

        # set delay pattern mask for the rest of the generation
        kwargs["user_delay_pattern_mask"] = (
            user_delay_pattern_mask if user_delay_pattern_mask is not None else kwargs.get("user_delay_pattern_mask")
        )
        kwargs["moshi_delay_pattern_mask"] = (
            moshi_delay_pattern_mask
            if moshi_delay_pattern_mask is not None
            else kwargs.get("moshi_delay_pattern_mask")
        )

        self.generated_audio_codes = torch.repeat_interleave(
            moshi_audio_codes, max(generation_config.num_beams, generation_config.num_return_sequences), dim=0
        )

        return_dict_in_generate = generation_config.num_beams > 1 or generation_config.return_dict_in_generate
        output_scores = generation_config.num_beams > 1 or generation_config.output_scores
        outputs = super().generate(
            inputs_embeds=inputs_embeds,
            input_ids=input_ids,
            generation_config=generation_config,
            blank_user_audio_codes=blank_user_audio_codes,
            kwargs_depth_decoder=kwargs_depth_decoder,
            return_dict_in_generate=return_dict_in_generate,
            output_scores=output_scores,
            attention_mask=attention_mask,
            **kwargs,
        )

        if not return_audio_waveforms and not return_audio_codes:
            if return_dict_in_generate and not generation_config.return_dict_in_generate:
                return outputs.sequences
            return outputs

        # check if outputs is a dict or tokens
        if not return_dict_in_generate:
            output_text_ids = outputs
        else:
            output_text_ids = outputs.sequences

        if generation_config.num_return_sequences > 1:
            moshi_delay_pattern_mask = torch.repeat_interleave(
                moshi_delay_pattern_mask, generation_config.num_return_sequences, dim=0
            )

        if generation_config.num_beams > 1:
            # we need to reorganize self.last_hidden_state and generated audio codes according to the beam_indices
            # Beam indices are of shape `input_length + number_generated_tokens` but actually start
            # indexing at index 0 instead of index `input_length - 1`.
            # We thus discard the last `input_length` indices that are never used.
            beam_indices = outputs.beam_indices[:, : -moshi_audio_codes.shape[-1]]

            generated_audio_codes = self.generated_audio_codes[:, :, moshi_audio_codes.shape[-1] :]

            # we've generated audio tokens `number_generated_tokens - 1` times, so we use the corresponding beam indices to
            # retrieve the right audio tokens
            expanded_beam_indices = beam_indices[:, :-1].unsqueeze(1).expand(-1, self.num_codebooks, -1)
            generated_audio_codes = torch.gather(generated_audio_codes, dim=0, index=expanded_beam_indices)

            # now, rebuild generated audio codes, this time with the right beam tracking
            moshi_audio_codes = torch.repeat_interleave(
                moshi_audio_codes, generation_config.num_return_sequences, dim=0
            )
            self.generated_audio_codes = torch.cat((moshi_audio_codes, generated_audio_codes), dim=2)

            # use the last beam index to retrieve the right self.last_hidden_state
            self.last_hidden_state = torch.index_select(self.last_hidden_state, dim=0, index=beam_indices[:, -1])

        # we need to make a last generation with the latest generated tokens
        last_hidden_state = self.last_hidden_state.view(-1, 1, self.last_hidden_state.shape[-1])

        last_generated_audio_codes = self.depth_decoder.generate(
            last_hidden_state=last_hidden_state,
            input_ids=output_text_ids[:, -1:].view(-1, 1),
            **kwargs_depth_decoder,
        )

        last_generated_audio_codes = last_generated_audio_codes[:, 1:].unsqueeze(2)

        self.generated_audio_codes = torch.cat([self.generated_audio_codes, last_generated_audio_codes], dim=2)

        # apply the pattern mask to the final audio ids
        output_audio_codes = self.apply_delay_pattern_mask(self.generated_audio_codes, moshi_delay_pattern_mask)

        # revert the pattern delay mask by filtering the pad token id and bos token ids
        mask = moshi_delay_pattern_mask != self.config.audio_vocab_size

        output_audio_codes = output_audio_codes[mask].reshape(mask.shape[0], self.num_codebooks, -1)

        output_values = None
        if return_audio_waveforms:
            output_values = self.audio_encoder.decode(
                output_audio_codes,
            ).audio_values

        output_audio_codes = output_audio_codes if return_audio_codes else None

        if generation_config.return_dict_in_generate:
            return MoshiConditionalGenerationGenerateOutput(
                audio_sequences=output_values, audio_codes=output_audio_codes, **outputs
            )

        return MoshiConditionalGenerationGenerateOutput(
            audio_sequences=output_values, sequences=output_text_ids, audio_codes=output_audio_codes
        )
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        user_delay_pattern_mask=None,
        moshi_delay_pattern_mask=None,
        kwargs_depth_decoder=None,
        blank_user_audio_codes: Optional[torch.FloatTensor] = None,
        **kwargs,
    ):
        # Overwritten -- Moshi has custom post-processing

        # 1. Do usual operations done on LLMs like Gemma - because we pre-processed inputs, the first pass always has inputs_embeds
        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            num_logits_to_keep=num_logits_to_keep,
            **kwargs,
        )

        # 2. Now that everything is prepared, generate audio_codes using the depth decoder
        # we want to do it after a first token has been generated
        if model_inputs["input_ids"] is not None:
            last_hidden_state = kwargs.get("last_hidden_state")
            # (batch_size, sequence_length, dim) -> (batch_size * sequence_length, 1, dim)
            last_hidden_state = last_hidden_state.view(-1, 1, last_hidden_state.shape[-1])

            input_ids = model_inputs.pop("input_ids")

            generated_audio_codes = self.depth_decoder.generate(
                last_hidden_state=last_hidden_state,
                input_ids=input_ids.view(-1, 1),
                **kwargs_depth_decoder,
            )

            # the first tokens are text tokens
            generated_audio_codes = generated_audio_codes[:, 1:].unsqueeze(2)

            user_audio_codes = self.apply_delay_pattern_mask(
                torch.cat(
                    [self.generated_audio_codes, blank_user_audio_codes.to(self.generated_audio_codes.device)], dim=2
                ),
                user_delay_pattern_mask,
            )[:, :, -1:]
            self.generated_audio_codes = self.apply_delay_pattern_mask(
                torch.cat([self.generated_audio_codes, generated_audio_codes], dim=2), moshi_delay_pattern_mask
            )

            inputs_embeds, _, _, _, _, _, _ = self._prepare_inputs_embeds_for_generation(
                input_ids, moshi_audio_codes=self.generated_audio_codes[:, :, -1:], user_audio_codes=user_audio_codes
            )

            model_inputs["input_ids"] = None
            model_inputs["inputs_embeds"] = inputs_embeds

        return model_inputs

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder, num_new_tokens
        )

        # update last_hidden_state that'll be used in the depth decoder
        model_kwargs["last_hidden_state"] = outputs.get("last_hidden_state")[:, -1:]

        # dirty, but we need to make a last depth_decoder.generate
        self.last_hidden_state = outputs.get("last_hidden_state")[:, -1:]
        return model_kwargs
    def get_input_embeddings(self):
        return self.decoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.decoder.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.decoder.set_output_embeddings(new_embeddings)

    def freeze_audio_encoder(self):
        """
        Freeze the audio encoder weights.
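
        Example (a minimal sketch; assumes `model` is an already loaded [`MoshiForConditionalGeneration`] instance):

        ```python
        >>> model.freeze_audio_encoder()
        >>> any(p.requires_grad for p in model.audio_encoder.parameters())
        False
        ```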
        """
        for param in self.audio_encoder.parameters():
            param.requires_grad = False
        self.audio_encoder._requires_grad = False

    def freeze_depth_decoder(self):
        """
        Freeze the depth decoder weights.
        """
        for param in self.depth_decoder.parameters():
            param.requires_grad = False
        self.depth_decoder._requires_grad = False
    @staticmethod
    # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM.apply_delay_pattern_mask
    def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
        """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
        the mask is set to -1, and otherwise setting to the value detailed in the mask."""
        seq_len = input_ids.shape[-1]
        decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
        input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
        return input_ids

    def build_delay_pattern_mask(
        self, input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int = None
    ):
        """Build a delayed pattern mask for the input_ids. Each codebook, except the first one, is offset by
        one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
        are 4 codebooks and a max sequence length of 6, we have the delayed pattern mask of shape `(codebooks,
        seq_len)`:
        - [-1, -1, -1, -1, -1,  P]
        - [ B, -1, -1, -1, -1, -1]
        - [ B, -1, -1, -1, -1, -1]
        - [ B, -1, -1, -1, -1, -1]
        where B is the beginning-of-sequence token, P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
        a prompt (input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
        mask is set to the value in the prompt:
        - [ a0, a1, -1, -1, -1,  P]
        - [ B,  b0, b1, -1, -1, -1]
        - [ B,  c0, c1, -1, -1, -1]
        - [ B,  d0, d1, -1, -1, -1]
        where a-d indicate the codebook channel and 0/1 indicate the temporal position. Now, we only override the -1
        tokens in our prediction.
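
        Example (an illustrative sketch of how such a pattern mask is applied; the tensors below are toy values, not
        a mask actually produced by this method):

        ```python
        >>> import torch
        >>> from transformers import MoshiForConditionalGeneration

        >>> # toy "generated" codes: 1 batch, 2 codebooks, 2 timesteps
        >>> codes = torch.tensor([[[10, 11], [12, 13]]])
        >>> # toy pattern mask of max length 3: -1 marks positions kept from `codes`
        >>> pattern = torch.tensor([[[-1, -1, 2048], [2048, -1, -1]]])
        >>> MoshiForConditionalGeneration.apply_delay_pattern_mask(codes, pattern).tolist()
        [[[10, 11], [2048, 13]]]
        ```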
        """
        bsz, num_codebooks, seq_len = input_ids.shape

        max_length = max_length if max_length is not None else self.generation_config.max_length
        input_ids_shifted = (
            torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
        )

        # the first codebook channel is not shifted
        seq_len_to_keep = min(seq_len, max_length - 1)
        input_ids_shifted[:, 0, :seq_len_to_keep] = input_ids[:, 0, :seq_len_to_keep]

        # fill the shifted ids with the prompt entries
        input_ids_shifted[:, 1:, 1 : seq_len_to_keep + 1] = input_ids[:, 1:, :seq_len_to_keep]

        # fill with BOS and PAD
        input_ids_shifted[:, 1:, 0] = bos_token_id
        input_ids_shifted[:, 0, -1] = pad_token_id

        # construct a pattern mask that indicates the positions of BOS and PAD tokens for each codebook
        pattern_mask = input_ids_shifted

        input_ids = input_ids_shifted[..., :seq_len_to_keep]
        return input_ids, pattern_mask
    def get_unconditional_inputs(self, num_samples=1):
        """
        Helper function to get null inputs for unconditional generation, enabling the model to be used without the
        feature extractor or tokenizer.

        Args:
            num_samples (int, *optional*):
                Number of audio samples to unconditionally generate.

        Example:
        ```python
        >>> from transformers import MoshiForConditionalGeneration

        >>> model = MoshiForConditionalGeneration.from_pretrained("kmhf/hf-moshiko-pytorch-bf16")

        >>> # get the unconditional (or 'null') inputs for the model
        >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
        >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
        ```"""
        input_ids = torch.ones((num_samples, 1), device=self.device, dtype=torch.int64) * self.config.vocab_size
        user_audio_codes = (
            torch.ones((num_samples, self.num_codebooks, 1), device=self.device, dtype=torch.int64)
            * self.config.audio_vocab_size
        )
        moshi_audio_codes = (
            torch.ones((num_samples, self.num_codebooks, 1), device=self.device, dtype=torch.int64)
            * self.config.audio_vocab_size
        )
        attention_mask = torch.ones((num_samples, 1), device=self.device, dtype=torch.long)

        return MoshiUnconditionalInput(
            input_ids=input_ids,
            user_audio_codes=user_audio_codes,
            moshi_audio_codes=moshi_audio_codes,
            attention_mask=attention_mask,
        )
    def _check_and_maybe_initalize_inputs(
        self,
        input_ids=None,
        user_input_values=None,
        user_audio_codes=None,
        moshi_input_values=None,
        moshi_audio_codes=None,
        inputs_embeds=None,
        concat_unconditional_inputs=None,
    ):
        inputs = input_ids if inputs_embeds is None else inputs_embeds
        user_input = user_audio_codes if user_input_values is None else user_input_values
        moshi_input = moshi_audio_codes if moshi_input_values is None else moshi_input_values

        one_input_has_been_passed = (user_input is not None) or (moshi_input is not None) or (inputs is not None)

        # concat_unconditional_inputs will be False if inputs_embeds is used
        concat_unconditional_inputs = concat_unconditional_inputs and not (
            inputs_embeds is not None and input_ids is None
        )

        # if only one or two of the three required inputs have been passed, throw an error
        if one_input_has_been_passed and (user_input is None):
            raise ValueError(
                "No user audio inputs have been passed alongside the other inputs. Make sure either `user_input_values` or `user_audio_codes` is passed or use `MoshiForConditionalGeneration.get_unconditional_inputs`. Check the `MoshiForConditionalGeneration` docstrings for more information."
            )
        elif one_input_has_been_passed and (moshi_input is None):
            raise ValueError(
                "No Moshi audio inputs have been passed alongside the other inputs. Make sure either `moshi_input_values` or `moshi_audio_codes` is passed or use `MoshiForConditionalGeneration.get_unconditional_inputs`. Check the `MoshiForConditionalGeneration` docstrings for more information."
            )
        elif one_input_has_been_passed and (inputs is None):
            raise ValueError(
                "No `input_ids` or `inputs_embeds` have been passed alongside the other inputs. Make sure `input_ids` is passed or use `MoshiForConditionalGeneration.get_unconditional_inputs`. Check the `MoshiForConditionalGeneration` docstrings for more information."
            )
        elif not one_input_has_been_passed:
            # if no inputs have been passed, use default values
            unconditional_inputs = self.get_unconditional_inputs()
            input_ids = unconditional_inputs.input_ids
            user_audio_codes = unconditional_inputs.user_audio_codes
            moshi_audio_codes = unconditional_inputs.moshi_audio_codes

            # in that case, no need to concat unconditional inputs
            concat_unconditional_inputs = False
        else:
            # check that all inputs share the same sequence length
            user_seq_length = user_input.shape[-1]
            moshi_seq_length = moshi_input.shape[-1]
            tokens_seq_length = inputs.shape[1]

            ratio = self.config.audio_encoder_config.frame_rate / self.config.sampling_rate
            moshi_seq_length = math.ceil(moshi_seq_length * ratio) if moshi_audio_codes is None else moshi_seq_length
            user_seq_length = math.ceil(user_seq_length * ratio) if user_audio_codes is None else user_seq_length

            if tokens_seq_length != moshi_seq_length or tokens_seq_length != user_seq_length:
                raise ValueError(
                    "At least one of the 3 inputs of `MoshiForConditionalGeneration` doesn't have the same sequence length as the others."
                    " Make sure that they all have the same sequence length. Check the `MoshiForConditionalGeneration` docstrings for more information."
                )

        return input_ids, user_audio_codes, moshi_audio_codes, concat_unconditional_inputs
    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return tuple(
            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
            for layer_past in past_key_values
        )


__all__ = ["MoshiForCausalLM", "MoshiForConditionalGeneration", "MoshiModel", "MoshiPreTrainedModel"]