modeling_jamba.py

# coding=utf-8
# Copyright 2024 AI21 Labs Ltd. and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Jamba model."""

import math
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache  # we need __iter__ and __len__ of pkv
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
    AttentionMaskConverter,
)
from ...modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from ...utils.import_utils import (
    is_causal_conv1d_available,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    is_mamba_ssm_available,
)
from .configuration_jamba import JambaConfig


if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward

if is_mamba_ssm_available():
    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
else:
    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all(
    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "JambaConfig"

# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func with gate->router
def load_balancing_loss_func(
    router_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
    num_experts: Optional[int] = None,
    top_k=2,
    attention_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, int]:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        router_logits:
            Logits from the `router`, should be a tuple of model.config.num_hidden_layers tensors of
            shape [batch_size X sequence_length, num_experts].
        num_experts:
            Number of experts.
        top_k:
            The number of experts to route per token, can also be interpreted as the `top-k` routing
            parameter.
        attention_mask (`torch.Tensor`, *optional*):
            The attention_mask used in the forward function,
            shape [batch_size X sequence_length] if not None.

    Returns:
        The auxiliary loss.
    """
    if router_logits is None or not isinstance(router_logits, tuple):
        return 0

    if isinstance(router_logits, tuple):
        compute_device = router_logits[0].device
        concatenated_router_logits = torch.cat(
            [layer_router.to(compute_device) for layer_router in router_logits], dim=0
        )

    routing_weights = torch.nn.functional.softmax(concatenated_router_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_router_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape as expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape as tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts
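

# Illustrative only (a sketch of how this loss is typically consumed, not code from this file):
# the per-layer router logits collected during the forward pass are reduced with this function,
# and the result is scaled by an auxiliary-loss coefficient before being added to the LM loss, e.g.
#   aux_loss = load_balancing_loss_func(router_logits, num_experts, top_k, attention_mask)
#   loss = lm_loss + router_aux_loss_coef * aux_loss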


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Jamba
class JambaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        JambaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
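

# For reference: RMSNorm rescales by the root-mean-square of the features instead of centering them,
# i.e. y = weight * x / sqrt(mean(x**2, dim=-1) + eps), with the reduction done in float32 as above.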


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
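

# Shape example: with n_rep = 4, key/value states of shape (2, 8, 128, 64)
# (batch, num_key_value_heads, seq_len, head_dim) become (2, 32, 128, 64), matching the query heads.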


class HybridMambaAttentionDynamicCache(DynamicCache):
    """
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors, and the expected shape of each
    tensor depends on the layer type:
    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
    """

    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        super().__init__()
        self.dtype = dtype
        self.layers_block_type = config.layers_block_type
        self.has_previous_state = False  # only used by mamba
        intermediate_size = config.mamba_expand * config.hidden_size
        ssm_state_size = config.mamba_d_state
        conv_kernel_size = config.mamba_d_conv
        self.conv_states = []
        self.ssm_states = []
        self.transformer_layers = []
        for i in range(config.num_hidden_layers):
            if self.layers_block_type[i] == "mamba":
                self.conv_states += [
                    torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
                ]
                self.ssm_states += [
                    torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
                ]
            else:
                self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
                self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
                self.transformer_layers.append(i)

        self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
        self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Update the cache
        if self.key_cache[layer_idx].shape[-1] == 0:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

            device = self.conv_states[layer_idx].device
            self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
            device = self.ssm_states[layer_idx].device
            self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # take any layer that contains cache and is not an empty tensor
        layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
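

# Sketch of how this cache is typically constructed before generation (hypothetical variable names;
# dtype/device should match the model):
#   past_key_values = HybridMambaAttentionDynamicCache(config, batch_size, dtype=model.dtype, device=model.device)
#   outputs = model(input_ids, past_key_values=past_key_values, use_cache=True)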


# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Jamba
class JambaAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """

    def __init__(self, config: JambaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
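

# The eager path above computes standard scaled dot-product attention,
# softmax(Q @ K^T / sqrt(head_dim) + causal_mask) @ V, with grouped-query attention handled by
# `repeat_kv` whenever num_key_value_heads < num_attention_heads.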


# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba
class JambaFlashAttention2(JambaAttention):
    """
    Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stay
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a
        # wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x head_dim x hidden_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=getattr(self.config, "sliding_window", None),
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba
class JambaSdpaAttention(JambaAttention):
    """
    Jamba attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `JambaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
    the SDPA API.
    """

    # Adapted from JambaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "JambaModel is using JambaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


JAMBA_ATTENTION_CLASSES = {
    "eager": JambaAttention,
    "flash_attention_2": JambaFlashAttention2,
    "sdpa": JambaSdpaAttention,
}
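
# The concrete attention class is selected from this mapping with the key stored in
# `config._attn_implementation` (see `JambaAttentionDecoderLayer.__init__` below).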


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class JambaMambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    """

    def __init__(self, config: JambaConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.intermediate_size = config.mamba_expand * config.hidden_size
        self.time_step_rank = config.mamba_dt_rank
        self.use_conv_bias = config.mamba_conv_bias
        self.use_bias = config.mamba_proj_bias
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=self.use_conv_bias,
            kernel_size=self.conv_kernel_size,
            groups=self.intermediate_size,
            padding=self.conv_kernel_size - 1,
        )

        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        self.use_fast_kernels = config.use_mamba_kernels

        # projection of the input hidden states
        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=self.use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
        # time step projection (discretization)
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.ssm_state_size + 1)[None, :]
        A = A.expand(self.intermediate_size, -1).contiguous()

        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.intermediate_size))
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

        self.dt_layernorm = JambaRMSNorm(self.time_step_rank, eps=config.rms_norm_eps)
        self.b_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)
        self.c_layernorm = JambaRMSNorm(self.ssm_state_size, eps=config.rms_norm_eps)

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                " is None. To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
            )
    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: HybridMambaAttentionDynamicCache = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        batch_size, seq_len, _ = hidden_states.shape
        use_precomputed_states = (
            cache_params is not None
            and cache_params.has_previous_state
            and seq_len == 1
            and cache_params.conv_states[self.layer_idx].shape[0]
            == cache_params.ssm_states[self.layer_idx].shape[0]
            == batch_size
        )
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states).transpose(1, 2)

        # We can't use `mamba_inner_fn` even if in training and without cache params because we have the
        # inner layernorms which aren't supported by this fused kernel
        hidden_states, gate = projected_states.chunk(2, dim=1)

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 2. Convolution sequence transformation
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
        if use_precomputed_states:
            hidden_states = causal_conv1d_update(
                hidden_states.squeeze(-1),
                cache_params.conv_states[self.layer_idx],
                conv_weights,
                self.conv1d.bias,
                self.activation,
            )
            hidden_states = hidden_states.unsqueeze(-1)
        else:
            if cache_params is not None:
                conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
                cache_params.conv_states[self.layer_idx].copy_(conv_states)
            hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 3. State Space Model sequence transformation
        # 3.a. input varying initialization of time_step, B and C
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
        time_step, B, C = torch.split(
            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
        )

        time_step = self.dt_layernorm(time_step)
        B = self.b_layernorm(B)
        C = self.c_layernorm(C)

        # Here we need to apply dt_proj without the bias, as the bias is added in the selective scan kernel.
        # This is a hack to apply dt_proj while still using the forward pass of `torch.nn.Linear`, which is needed
        # in order to make quantization work. Quantization code replaces `torch.nn.Linear` layers with quantized
        # linear layers, and requires to call the forward pass directly.
        # A quantized model can't work with the original code:
        # ```discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)```
        time_proj_bias = self.dt_proj.bias.data
        with torch.no_grad():
            self.dt_proj.bias.data = torch.zeros_like(self.dt_proj.bias.data)
        discrete_time_step = self.dt_proj(time_step).transpose(1, 2)
        with torch.no_grad():
            self.dt_proj.bias.data = time_proj_bias

        A = -torch.exp(self.A_log.float())
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        time_proj_bias = time_proj_bias.float() if time_proj_bias is not None else None
        if use_precomputed_states:
            scan_outputs = selective_state_update(
                cache_params.ssm_states[self.layer_idx],
                hidden_states[..., 0],
                discrete_time_step[..., 0],
                A,
                B[:, 0],
                C[:, 0],
                self.D,
                gate[..., 0],
                time_proj_bias,
                dt_softplus=True,
            ).unsqueeze(-1)
        else:
            scan_outputs, ssm_state = selective_scan_fn(
                hidden_states,
                discrete_time_step,
                A,
                B.transpose(1, 2),
                C.transpose(1, 2),
                self.D.float(),
                gate,
                time_proj_bias,
                delta_softplus=True,
                return_last_state=True,
            )
            if ssm_state is not None and cache_params is not None:
                cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
        return contextualized_states
    # fmt: off
    def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask: Optional[torch.LongTensor] = None):
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
        hidden_states, gate = projected_states.chunk(2, dim=1)

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
        # 2. Convolution sequence transformation
        if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
            if self.training:
                # In training mode, we don't want to perform in-place operations on ssm_state so we can compute the backwards pass
                ssm_state = cache_params.ssm_states[self.layer_idx].clone()
            else:
                ssm_state = cache_params.ssm_states[self.layer_idx]

            ssm_state = ssm_state.to(hidden_states.device)

            if cache_params.has_previous_state and seq_len == 1 and \
                    cache_params.conv_states[self.layer_idx].shape[0] == batch_size:
                conv_state = cache_params.conv_states[self.layer_idx]  # [batch, intermediate_size, conv_kernel_size]
                conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
                conv_state[:, :, -1] = hidden_states[:, :, 0]
                cache_params.conv_states[self.layer_idx] = conv_state
                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
                if self.use_conv_bias:
                    hidden_states += self.conv1d.bias
                hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)  # [batch, intermediate_size, 1] : decoding
            else:
                conv_state = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
                )
                cache_params.conv_states[self.layer_idx] = conv_state
                hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]
        else:
            ssm_state = torch.zeros(
                (batch_size, self.intermediate_size, self.ssm_state_size),
                device=hidden_states.device, dtype=dtype
            )
            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 3. State Space Model sequence transformation
        # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
        time_step, B, C = torch.split(
            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
        )

        time_step = self.dt_layernorm(time_step)
        B = self.b_layernorm(B)
        C = self.c_layernorm(C)

        discrete_time_step = self.dt_proj(time_step)  # [batch, seq_len, intermediate_size]
        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)  # [batch, intermediate_size, seq_len]

        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
        A = -torch.exp(self.A_log.float())  # [intermediate_size, ssm_state_size]
        discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])  # [batch, intermediate_size, seq_len, ssm_state_size]
        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()  # [batch, intermediate_size, seq_len, ssm_state_size]
        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()
        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        scan_outputs = []
        for i in range(seq_len):
            ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediate_size, ssm_state]
            scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
            scan_outputs.append(scan_output[:, :, 0])
        scan_output = torch.stack(scan_outputs, dim=-1)  # [batch, intermediate_size, seq_len]
        scan_output = scan_output + (hidden_states * self.D[None, :, None])
        scan_output = (scan_output * self.act(gate))

        if use_cache:
            cache_params.ssm_states[self.layer_idx] = ssm_state

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
        return contextualized_states
    # fmt: on
    def forward(
        self,
        hidden_states,
        cache_params: HybridMambaAttentionDynamicCache = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        if self.use_fast_kernels:
            if not is_fast_path_available or "cuda" not in self.x_proj.weight.device.type:
                raise ValueError(
                    "Fast Mamba kernels are not available. Make sure they are installed and that the mamba module is on a CUDA device"
                )
            return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask)
        return self.slow_forward(hidden_states, cache_params, attention_mask)
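

# For reference, the recurrence implemented by `slow_forward` (and by the fused kernels on the fast
# path) is the discretized selective SSM:
#   A_bar_t = exp(delta_t * A),  B_bar_t = delta_t * B_t
#   h_t = A_bar_t * h_{t-1} + B_bar_t * x_t,   y_t = C_t @ h_t + D * x_t
# followed by gating with `act(gate)` and the final output projection.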


# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Jamba
class JambaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
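

# This is the usual gated MLP: down_proj(act_fn(gate_proj(x)) * up_proj(x)),
# with the activation taken from `config.hidden_act`.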


# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock with Mistral->Jamba
class JambaSparseMoeBlock(nn.Module):
    """
    This implementation is strictly equivalent to standard MoE with full capacity (no dropped tokens). It's faster
    since it formulates MoE operations in terms of block-sparse operations to accommodate imbalanced assignments of
    tokens to experts, whereas standard MoE either (1) drops tokens at the cost of reduced performance or (2) sets
    the capacity factor to the number of experts and thus wastes computation and memory on padding.
    """

    def __init__(self, config: JambaConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok

        self.router = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = nn.ModuleList([JambaMLP(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """ """
        batch_size, sequence_length, hidden_dim = hidden_states.shape

        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.router(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be solicited
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only supports torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
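

# Routing example: with, say, 16 experts and top_k=2, each token gets softmax probabilities over the
# 16 experts, the 2 largest are kept as mixing weights, and only those 2 experts are run on that
# token; their weighted outputs are accumulated into `final_hidden_states` via `index_add_`.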


class JambaAttentionDecoderLayer(nn.Module):
    def __init__(self, config: JambaConfig, layer_idx: int):
        super().__init__()
        num_experts = config.layers_num_experts[layer_idx]
        self.self_attn = JAMBA_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
        self.feed_forward = ffn_layer_class(config)
        self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_router_logits (`bool`, *optional*):
                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
                and should not be returned during inference.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )

        # residual connection after attention
        hidden_states = residual + hidden_states

        # feed-forward (experts/MLP)
        residual = hidden_states
        hidden_states = self.pre_ff_layernorm(hidden_states)
        ff_outputs = self.feed_forward(hidden_states)
        if isinstance(ff_outputs, tuple):
            hidden_states, router_logits = ff_outputs
        else:
            hidden_states, router_logits = ff_outputs, None
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        if output_router_logits:
            outputs += (router_logits,)

        return outputs
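

# Block structure of the attention decoder layer above: RMSNorm -> self-attention -> residual add,
# then RMSNorm -> MoE block (or plain JambaMLP when the layer has a single expert) -> residual add.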


class JambaMambaDecoderLayer(nn.Module):
    def __init__(self, config: JambaConfig, layer_idx: int):
        super().__init__()
        num_experts = config.layers_num_experts[layer_idx]
        self.mamba = JambaMambaMixer(config=config, layer_idx=layer_idx)

        ffn_layer_class = JambaSparseMoeBlock if num_experts > 1 else JambaMLP
        self.feed_forward = ffn_layer_class(config)
        self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_router_logits (`bool`, *optional*):
                Whether or not to return the logits of all the routers. They are useful for computing the router loss,
                and should not be returned during inference.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.mamba(
            hidden_states=hidden_states,
            cache_params=past_key_value,
            attention_mask=attention_mask,
        )
        self_attn_weights = None

        # residual connection after mamba
        hidden_states = residual + hidden_states

        # feed-forward (experts/MLP)
        residual = hidden_states
        hidden_states = self.pre_ff_layernorm(hidden_states)
        ff_outputs = self.feed_forward(hidden_states)
        if isinstance(ff_outputs, tuple):
            hidden_states, router_logits = ff_outputs
        else:
            hidden_states, router_logits = ff_outputs, None
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (past_key_value,)

        if output_router_logits:
            outputs += (router_logits,)

        return outputs
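

# The mamba decoder layer mirrors the attention layer (norm -> mixer -> residual, norm -> FFN ->
# residual) and returns the same tuple layout, with attention weights always None. Which of the two
# layer types appears at a given depth is presumably driven by `config.layers_block_type`, the same
# field used by HybridMambaAttentionDynamicCache above.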


JAMBA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`JambaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare Jamba Model outputting raw hidden-states without any specific head on top.",
    JAMBA_START_DOCSTRING,
)
class JambaPreTrainedModel(PreTrainedModel):
    config_class = JambaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True  # Note: only supports HybridMambaAttentionDynamicCache
    _is_stateful = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


JAMBA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            A `HybridMambaAttentionDynamicCache` object containing pre-computed hidden-states (keys and values in the
            self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
            Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
            Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
            `(batch_size, d_inner, d_state)` respectively.
            See the `HybridMambaAttentionDynamicCache` class for more details.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_router_logits (`bool`, *optional*):
            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
            should not be returned during inference.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrary to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
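
    As a minimal sketch, the hybrid cache can also be created explicitly and passed to the model; during
    generation, `prepare_inputs_for_generation` builds it automatically. The checkpoint name and prompt
    below are illustrative assumptions:

    ```python
    >>> from transformers import AutoTokenizer, JambaForCausalLM
    >>> from transformers.models.jamba.modeling_jamba import HybridMambaAttentionDynamicCache

    >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
    >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
    >>> inputs = tokenizer("Hello", return_tensors="pt")

    >>> # batch size, dtype and device follow the positional signature used in `prepare_inputs_for_generation`
    >>> past_key_values = HybridMambaAttentionDynamicCache(
    ...     model.config, inputs.input_ids.shape[0], model.dtype, device=model.device
    ... )
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    ```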
  993. """
  994. ALL_DECODER_LAYER_TYPES = {"attention": JambaAttentionDecoderLayer, "mamba": JambaMambaDecoderLayer}
  995. @add_start_docstrings(
  996. "The bare Jamba Model outputting raw hidden-states without any specific head on top.",
  997. JAMBA_START_DOCSTRING,
  998. )
# Adapted from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->JAMBA, Mistral->Jamba
class JambaModel(JambaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is either a
    [`JambaAttentionDecoderLayer`] or a [`JambaMambaDecoderLayer`], as selected by `config.layers_block_type`.

    Args:
        config: JambaConfig
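
    A small sketch of how the hybrid layout is resolved. The 8-layer value is an illustrative assumption;
    with the default `attn_layer_period` / `attn_layer_offset` the attention layer sits at index 4:

    ```python
    >>> from transformers import JambaConfig

    >>> config = JambaConfig(num_hidden_layers=8)
    >>> config.layers_block_type
    ['mamba', 'mamba', 'mamba', 'mamba', 'attention', 'mamba', 'mamba', 'mamba']
    ```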
  1005. """
  1006. def __init__(self, config: JambaConfig):
  1007. super().__init__(config)
  1008. self.padding_idx = config.pad_token_id
  1009. self.vocab_size = config.vocab_size
  1010. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  1011. decoder_layers = []
  1012. for i in range(config.num_hidden_layers):
  1013. layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]]
  1014. decoder_layers.append(layer_class(config, layer_idx=i))
  1015. self.layers = nn.ModuleList(decoder_layers)
  1016. self._attn_implementation = config._attn_implementation
  1017. self.final_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  1018. self.gradient_checkpointing = False
  1019. # Initialize weights and apply final processing
  1020. self.post_init()
  1021. def get_input_embeddings(self):
  1022. return self.embed_tokens
  1023. def set_input_embeddings(self, value):
  1024. self.embed_tokens = value
  1025. @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
  1026. def forward(
  1027. self,
  1028. input_ids: torch.LongTensor = None,
  1029. attention_mask: Optional[torch.Tensor] = None,
  1030. position_ids: Optional[torch.LongTensor] = None,
  1031. past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
  1032. inputs_embeds: Optional[torch.FloatTensor] = None,
  1033. use_cache: Optional[bool] = None,
  1034. output_attentions: Optional[bool] = None,
  1035. output_hidden_states: Optional[bool] = None,
  1036. output_router_logits: Optional[bool] = None,
  1037. return_dict: Optional[bool] = None,
  1038. cache_position: Optional[torch.LongTensor] = None,
  1039. ) -> Union[Tuple, MoeModelOutputWithPast]:
  1040. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1041. output_router_logits = (
  1042. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  1043. )
  1044. output_hidden_states = (
  1045. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1046. )
  1047. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1048. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1049. if (input_ids is None) ^ (inputs_embeds is not None):
  1050. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1051. if self.gradient_checkpointing and self.training and use_cache:
  1052. logger.warning_once(
  1053. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  1054. )
  1055. use_cache = False
  1056. if inputs_embeds is None:
  1057. inputs_embeds = self.embed_tokens(input_ids)
  1058. hidden_states = inputs_embeds
  1059. if use_cache and past_key_values is None:
  1060. logger.warning_once(
  1061. "Jamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
  1062. "provided, so no cache will be returned."
  1063. )
  1064. if cache_position is None:
  1065. cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
  1066. if position_ids is None:
  1067. position_ids = cache_position.unsqueeze(0)
  1068. causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  1069. mamba_mask = self._update_mamba_mask(attention_mask, cache_position)
  1070. all_hidden_states = () if output_hidden_states else None
  1071. all_self_attns = () if output_attentions else None
  1072. all_router_logits = () if output_router_logits else None
  1073. for decoder_layer in self.layers:
  1074. # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
  1075. layer_mask = mamba_mask if isinstance(decoder_layer, JambaMambaDecoderLayer) else causal_mask
  1076. if output_hidden_states:
  1077. all_hidden_states += (hidden_states,)
  1078. if self.gradient_checkpointing and self.training:
  1079. layer_outputs = self._gradient_checkpointing_func(
  1080. decoder_layer.__call__,
  1081. hidden_states,
  1082. layer_mask,
  1083. position_ids,
  1084. past_key_values,
  1085. output_attentions,
  1086. output_router_logits,
  1087. use_cache,
  1088. cache_position,
  1089. )
  1090. else:
  1091. layer_outputs = decoder_layer(
  1092. hidden_states,
  1093. attention_mask=layer_mask,
  1094. position_ids=position_ids,
  1095. past_key_value=past_key_values,
  1096. output_attentions=output_attentions,
  1097. output_router_logits=output_router_logits,
  1098. use_cache=use_cache,
  1099. cache_position=cache_position,
  1100. )
  1101. hidden_states = layer_outputs[0]
  1102. if output_attentions:
  1103. if layer_outputs[1] is not None:
  1104. # append attentions only of attention layers. Mamba layers return `None` as the attention weights
  1105. all_self_attns += (layer_outputs[1],)
  1106. if output_router_logits:
  1107. if layer_outputs[-1] is not None:
  1108. # append router logits only of expert layers. Regular MLP layers return `None` as the router logits
  1109. all_router_logits += (layer_outputs[-1],)
  1110. hidden_states = self.final_layernorm(hidden_states)
  1111. # add hidden states from the last decoder layer
  1112. if output_hidden_states:
  1113. all_hidden_states += (hidden_states,)
  1114. if past_key_values and not past_key_values.has_previous_state:
  1115. past_key_values.has_previous_state = True
  1116. next_cache = None if not use_cache else past_key_values
  1117. if not return_dict:
  1118. return tuple(
  1119. v
  1120. for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
  1121. if v is not None
  1122. )
  1123. return MoeModelOutputWithPast(
  1124. last_hidden_state=hidden_states,
  1125. past_key_values=next_cache,
  1126. hidden_states=all_hidden_states,
  1127. attentions=all_self_attns,
  1128. router_logits=all_router_logits,
  1129. )
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        target_length = cache_position[-1] + 1

        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    def _update_mamba_mask(self, attention_mask, cache_position):
        """
        No need for zeroing states when
            1. Cached forward
            2. Attending to all inputs
        """
        mamba_mask = attention_mask
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            mamba_mask = None
        return mamba_mask


# Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba
class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: JambaConfig):
        super().__init__(config)
        self.model = JambaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    # Ignore copy
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: Optional[int] = None,
        **loss_kwargs,
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int` or `None`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
                `input_ids`. Only last token logits are needed for generation, and calculating them only for that token
                can save memory, which becomes pretty significant for long sequences.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, JambaForCausalLM

        >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
        >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            cache_position=cache_position,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if num_logits_to_keep is None:
            logits = self.lm_head(hidden_states)
        else:
            logits = self.lm_head(hidden_states[..., -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits if return_dict else outputs[-1],
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

        if not return_dict:
            output = (logits,) + outputs[1:]
            if output_router_logits:
                output = (aux_loss,) + output
            return (loss,) + output if loss is not None else output

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        output_router_logits=False,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`

        empty_past_kv = past_key_values is None

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if not empty_past_kv:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        else:
            past_key_values = HybridMambaAttentionDynamicCache(
                self.config, input_ids.shape[0], self.dtype, device=self.device
            )

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "output_router_logits": output_router_logits,
                "num_logits_to_keep": self.config.num_logits_to_keep,
                "cache_position": cache_position,
            }
        )
        return model_inputs


@add_start_docstrings(
    """
    The Jamba Model with a sequence classification head on top (linear layer).

    [`JambaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it needs to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
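
    A minimal usage sketch. The checkpoint name and `num_labels` below are illustrative assumptions; when a
    base checkpoint is loaded, the classification head is newly (randomly) initialized:

    ```python
    >>> import torch
    >>> from transformers import AutoTokenizer, JambaForSequenceClassification

    >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
    >>> model = JambaForSequenceClassification.from_pretrained("ai21labs/Jamba-v0.1", num_labels=2)

    >>> inputs = tokenizer("The movie was great!", return_tensors="pt")
    >>> with torch.no_grad():
    ...     logits = model(**inputs).logits
    >>> predicted_class_id = logits.argmax().item()
    ```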
  1351. """,
  1352. JAMBA_START_DOCSTRING,
  1353. )
  1354. # Copied from transformers.models.mixtral.modeling_mixtral.MixtralForSequenceClassification with Mixtral->Jamba, MIXTRAL->JAMBA
  1355. class JambaForSequenceClassification(JambaPreTrainedModel):
  1356. def __init__(self, config):
  1357. super().__init__(config)
  1358. self.num_labels = config.num_labels
  1359. self.model = JambaModel(config)
  1360. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  1361. # Initialize weights and apply final processing
  1362. self.post_init()
  1363. def get_input_embeddings(self):
  1364. return self.model.embed_tokens
  1365. def set_input_embeddings(self, value):
  1366. self.model.embed_tokens = value
  1367. @add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING)
  1368. def forward(
  1369. self,
  1370. input_ids: Optional[torch.LongTensor] = None,
  1371. attention_mask: Optional[torch.Tensor] = None,
  1372. position_ids: Optional[torch.LongTensor] = None,
  1373. past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  1374. inputs_embeds: Optional[torch.FloatTensor] = None,
  1375. labels: Optional[torch.LongTensor] = None,
  1376. use_cache: Optional[bool] = None,
  1377. output_attentions: Optional[bool] = None,
  1378. output_hidden_states: Optional[bool] = None,
  1379. return_dict: Optional[bool] = None,
  1380. ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
  1381. r"""
  1382. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1383. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1384. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1385. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1386. """
  1387. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1388. transformer_outputs = self.model(
  1389. input_ids,
  1390. attention_mask=attention_mask,
  1391. position_ids=position_ids,
  1392. past_key_values=past_key_values,
  1393. inputs_embeds=inputs_embeds,
  1394. use_cache=use_cache,
  1395. output_attentions=output_attentions,
  1396. output_hidden_states=output_hidden_states,
  1397. return_dict=return_dict,
  1398. )
  1399. hidden_states = transformer_outputs[0]
  1400. logits = self.score(hidden_states)
  1401. if input_ids is not None:
  1402. batch_size = input_ids.shape[0]
  1403. else:
  1404. batch_size = inputs_embeds.shape[0]
  1405. if self.config.pad_token_id is None and batch_size != 1:
  1406. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  1407. if self.config.pad_token_id is None:
  1408. sequence_lengths = -1
  1409. else:
  1410. if input_ids is not None:
  1411. # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
  1412. sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
  1413. sequence_lengths = sequence_lengths % input_ids.shape[-1]
  1414. sequence_lengths = sequence_lengths.to(logits.device)
  1415. else:
  1416. sequence_lengths = -1
  1417. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
  1418. loss = None
  1419. if labels is not None:
  1420. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  1421. if not return_dict:
  1422. output = (pooled_logits,) + transformer_outputs[1:]
  1423. return ((loss,) + output) if loss is not None else output
  1424. return SequenceClassifierOutputWithPast(
  1425. loss=loss,
  1426. logits=pooled_logits,
  1427. past_key_values=transformer_outputs.past_key_values,
  1428. hidden_states=transformer_outputs.hidden_states,
  1429. attentions=transformer_outputs.attentions,
  1430. )