modeling_mixtral.py 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616
  1. # coding=utf-8
  2. # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """PyTorch Mixtral model."""
  21. import math
  22. from typing import List, Optional, Tuple, Union
  23. import torch
  24. import torch.nn.functional as F
  25. import torch.utils.checkpoint
  26. from torch import nn
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
  29. from ...generation import GenerationMixin
  30. from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
  31. from ...modeling_outputs import (
  32. MoeCausalLMOutputWithPast,
  33. MoeModelOutputWithPast,
  34. QuestionAnsweringModelOutput,
  35. SequenceClassifierOutputWithPast,
  36. TokenClassifierOutput,
  37. )
  38. from ...modeling_utils import PreTrainedModel
  39. from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
  40. from ...utils import (
  41. add_code_sample_docstrings,
  42. add_start_docstrings,
  43. add_start_docstrings_to_model_forward,
  44. is_flash_attn_2_available,
  45. logging,
  46. replace_return_docstrings,
  47. )
  48. from ...utils.import_utils import is_torch_fx_available
  49. from .configuration_mixtral import MixtralConfig
if is_flash_attn_2_available():
    # The flash-attention helper is only importable when flash-attn is usable; it is called by
    # `MixtralFlashAttention2.forward` further down this module.
    from ...modeling_flash_attention_utils import _flash_attention_forward

# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        # NOTE(review): presumably older torch releases do not expose `torch.fx` after a plain `import torch`,
        # hence the explicit import before `torch.fx.wrap` is used — confirm.
        import torch.fx

    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)

logger = logging.get_logger(__name__)

# Checkpoint / config names for documentation examples; presumably consumed by the docstring decorators
# (`add_code_sample_docstrings`, `replace_return_docstrings`) further down the file.
_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1"
_CONFIG_FOR_DOC = "MixtralConfig"
  61. def load_balancing_loss_func(
  62. gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None],
  63. num_experts: Optional[int] = None,
  64. top_k=2,
  65. attention_mask: Optional[torch.Tensor] = None,
  66. ) -> Union[torch.Tensor, int]:
  67. r"""
  68. Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  69. See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
  70. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  71. experts is too unbalanced.
  72. Args:
  73. gate_logits:
  74. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  75. shape [batch_size X sequence_length, num_experts].
  76. num_experts:
  77. Number of experts
  78. top_k:
  79. The number of experts to route per-token, can be also interpreted as the `top-k` routing
  80. parameter.
  81. attention_mask (`torch.Tensor`, *optional*):
  82. The attention_mask used in forward function
  83. shape [batch_size X sequence_length] if not None.
  84. Returns:
  85. The auxiliary loss.
  86. """
  87. if gate_logits is None or not isinstance(gate_logits, tuple):
  88. return 0
  89. if isinstance(gate_logits, tuple):
  90. compute_device = gate_logits[0].device
  91. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  92. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  93. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  94. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  95. if attention_mask is None:
  96. # Compute the percentage of tokens routed to each experts
  97. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  98. # Compute the average probability of routing to these experts
  99. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  100. else:
  101. batch_size, sequence_length = attention_mask.shape
  102. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  103. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  104. expert_attention_mask = (
  105. attention_mask[None, :, :, None, None]
  106. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  107. .reshape(-1, top_k, num_experts)
  108. .to(compute_device)
  109. )
  110. # Compute the percentage of tokens routed to each experts
  111. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  112. expert_attention_mask, dim=0
  113. )
  114. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  115. router_per_expert_attention_mask = (
  116. attention_mask[None, :, :, None]
  117. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  118. .reshape(-1, num_experts)
  119. .to(compute_device)
  120. )
  121. # Compute the average probability of routing to these experts
  122. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  123. router_per_expert_attention_mask, dim=0
  124. )
  125. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  126. return overall_loss * num_experts
  127. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral
  128. class MixtralRMSNorm(nn.Module):
  129. def __init__(self, hidden_size, eps=1e-6):
  130. """
  131. MixtralRMSNorm is equivalent to T5LayerNorm
  132. """
  133. super().__init__()
  134. self.weight = nn.Parameter(torch.ones(hidden_size))
  135. self.variance_epsilon = eps
  136. def forward(self, hidden_states):
  137. input_dtype = hidden_states.dtype
  138. hidden_states = hidden_states.to(torch.float32)
  139. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  140. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  141. return self.weight * hidden_states.to(input_dtype)
  142. def extra_repr(self):
  143. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  144. # copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
  145. # TODO @longjie no longer copied from Mistral after static cache
  146. class MixtralRotaryEmbedding(nn.Module):
  147. def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
  148. super().__init__()
  149. self.dim = dim
  150. self.max_position_embeddings = max_position_embeddings
  151. self.base = base
  152. inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
  153. self.register_buffer("inv_freq", inv_freq, persistent=False)
  154. # Build here to make `torch.jit.trace` work.
  155. self._set_cos_sin_cache(
  156. seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
  157. )
  158. def _set_cos_sin_cache(self, seq_len, device, dtype):
  159. self.max_seq_len_cached = seq_len
  160. t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
  161. freqs = torch.outer(t, self.inv_freq)
  162. # Different from paper, but it uses a different permutation in order to obtain the same calculation
  163. emb = torch.cat((freqs, freqs), dim=-1)
  164. self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
  165. self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
  166. def forward(self, x, seq_len=None):
  167. # x: [bs, num_attention_heads, seq_len, head_size]
  168. if seq_len > self.max_seq_len_cached:
  169. self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
  170. return (
  171. self.cos_cached[:seq_len].to(dtype=x.dtype),
  172. self.sin_cached[:seq_len].to(dtype=x.dtype),
  173. )
  174. # Copied from transformers.models.llama.modeling_llama.rotate_half
  175. def rotate_half(x):
  176. """Rotates half the hidden dims of the input."""
  177. x1 = x[..., : x.shape[-1] // 2]
  178. x2 = x[..., x.shape[-1] // 2 :]
  179. return torch.cat((-x2, x1), dim=-1)
  180. # copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
  181. # TODO @longjie no longer copied from Mistral after static cache
  182. def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
  183. """Applies Rotary Position Embedding to the query and key tensors.
  184. Args:
  185. q (`torch.Tensor`): The query tensor.
  186. k (`torch.Tensor`): The key tensor.
  187. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  188. sin (`torch.Tensor`): The sine part of the rotary embedding.
  189. position_ids (`torch.Tensor`):
  190. The position indices of the tokens corresponding to the query and key tensors. For example, this can be
  191. used to pass offsetted position ids when working with a KV-cache.
  192. unsqueeze_dim (`int`, *optional*, defaults to 1):
  193. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  194. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  195. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  196. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  197. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  198. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  199. Returns:
  200. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  201. """
  202. cos = cos[position_ids].unsqueeze(unsqueeze_dim)
  203. sin = sin[position_ids].unsqueeze(unsqueeze_dim)
  204. q_embed = (q * cos) + (rotate_half(q) * sin)
  205. k_embed = (k * cos) + (rotate_half(k) * sin)
  206. return q_embed, k_embed
  207. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  208. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  209. """
  210. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  211. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  212. """
  213. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  214. if n_rep == 1:
  215. return hidden_states
  216. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  217. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  218. # copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral
  219. # TODO @longjie no longer copied from Mistral after static cache
class MixtralAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".

    This is the eager implementation: it materializes the full attention-weight matrix. Keys and values use
    `num_key_value_heads` heads that are repeated (`repeat_kv`) to match the number of query heads, and rotary
    position embeddings are applied to queries and keys.
    NOTE(review): this class never reads `config.sliding_window`; in the eager path any windowing must already be
    encoded in the `attention_mask` supplied by the caller — confirm.
    """

    def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        # Index of this layer in the decoder stack; used as the key for reading/updating the KV `Cache`.
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing one key/value head (grouped-query attention).
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = MixtralRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Views `tensor` as [bsz, seq_len, num_heads, head_dim] and returns it as a contiguous
        # [bsz, num_heads, seq_len, head_dim] tensor.
        # NOTE(review): not called by the visible code of this class — possibly legacy.
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Eager attention forward pass.

        Args:
            hidden_states: input of shape [bsz, q_len, hidden_size].
            attention_mask: 4D mask added to the raw attention scores, with its last axis sliced to the key length.
                Presumably a float mask holding large negative values at masked positions — confirm with the caller.
            position_ids: positions used to gather rows of the rotary `cos`/`sin` tables.
            past_key_value: optional KV cache; when given, new keys/values are merged via `Cache.update`.
            output_attentions: when False, the returned attention weights are `None`.
            use_cache: accepted for interface compatibility; not read by this method.
            cache_position: forwarded to `Cache.update` inside `cache_kwargs`.

        Returns:
            Tuple `(attn_output [bsz, q_len, hidden_size], attn_weights or None, past_key_value)`.

        Raises:
            ValueError: if a cache is given but `layer_idx` is None, or if an intermediate shape is unexpected.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Split heads: [bsz, q_len, heads * head_dim] -> [bsz, heads, q_len, head_dim].
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Key length = new tokens + tokens the cache already holds for this layer.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        # NOTE(review): the rotary tables are sliced to `kv_seq_len` rows, so `cos[position_ids]` would be out of
        # range for any position id >= kv_seq_len (the flash subclass guards this via `position_ids.max()`) — confirm
        # callers never pass such ids here.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Scaled dot-product scores of shape [bsz, num_heads, q_len, kv_seq_len].
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back: [bsz, num_heads, q_len, head_dim] -> [bsz, q_len, hidden_size].
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
  318. # copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
  319. # TODO @longjie no longer copied from Mistral after static cache
  320. class MixtralFlashAttention2(MixtralAttention):
  321. """
  322. Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
  323. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  324. flash attention and deal with padding tokens in case the input contains any of them.
  325. """
  326. def forward(
  327. self,
  328. hidden_states: torch.Tensor,
  329. attention_mask: Optional[torch.Tensor] = None,
  330. position_ids: Optional[torch.LongTensor] = None,
  331. past_key_value: Optional[Cache] = None,
  332. output_attentions: bool = False,
  333. use_cache: bool = False,
  334. cache_position: Optional[torch.LongTensor] = None,
  335. ):
  336. bsz, q_len, _ = hidden_states.size()
  337. query_states = self.q_proj(hidden_states)
  338. key_states = self.k_proj(hidden_states)
  339. value_states = self.v_proj(hidden_states)
  340. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  341. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  342. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  343. kv_seq_len = key_states.shape[-2]
  344. if past_key_value is not None:
  345. if self.layer_idx is None:
  346. raise ValueError(
  347. f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
  348. "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
  349. "with a layer index."
  350. )
  351. kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
  352. # Because the input can be padded, the absolute sequence length depends on the max position id.
  353. rotary_seq_len = (
  354. max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
  355. )
  356. cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
  357. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  358. if past_key_value is not None:
  359. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
  360. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  361. # repeat k/v heads if n_kv_heads < n_heads
  362. key_states = repeat_kv(key_states, self.num_key_value_groups)
  363. value_states = repeat_kv(value_states, self.num_key_value_groups)
  364. dropout_rate = 0.0 if not self.training else self.attention_dropout
  365. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  366. # therefore the input hidden states gets silently casted in float32. Hence, we need
  367. # cast them back in float16 just to be sure everything works as expected.
  368. input_dtype = query_states.dtype
  369. if input_dtype == torch.float32:
  370. if torch.is_autocast_enabled():
  371. target_dtype = torch.get_autocast_gpu_dtype()
  372. # Handle the case where the model is quantized
  373. elif hasattr(self.config, "_pre_quantization_dtype"):
  374. target_dtype = self.config._pre_quantization_dtype
  375. else:
  376. target_dtype = self.q_proj.weight.dtype
  377. logger.warning_once(
  378. f"The input hidden states seems to be silently casted in float32, this might be related to"
  379. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  380. f" {target_dtype}."
  381. )
  382. query_states = query_states.to(target_dtype)
  383. key_states = key_states.to(target_dtype)
  384. value_states = value_states.to(target_dtype)
  385. # Reashape to the expected shape for Flash Attention
  386. query_states = query_states.transpose(1, 2)
  387. key_states = key_states.transpose(1, 2)
  388. value_states = value_states.transpose(1, 2)
  389. attn_output = _flash_attention_forward(
  390. query_states,
  391. key_states,
  392. value_states,
  393. attention_mask,
  394. q_len,
  395. position_ids=position_ids,
  396. dropout=dropout_rate,
  397. sliding_window=getattr(self.config, "sliding_window", None),
  398. is_causal=self.is_causal,
  399. )
  400. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
  401. attn_output = self.o_proj(attn_output)
  402. if not output_attentions:
  403. attn_weights = None
  404. return attn_output, attn_weights, past_key_value
  405. # copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
  406. # TODO @longjie no longer copied from Mistral after static cache
  407. class MixtralSdpaAttention(MixtralAttention):
  408. """
  409. Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  410. `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  411. SDPA API.
  412. """
  413. # Adapted from MixtralAttention.forward
  414. def forward(
  415. self,
  416. hidden_states: torch.Tensor,
  417. attention_mask: Optional[torch.Tensor] = None,
  418. position_ids: Optional[torch.LongTensor] = None,
  419. past_key_value: Optional[Cache] = None,
  420. output_attentions: bool = False,
  421. use_cache: bool = False,
  422. cache_position: Optional[torch.LongTensor] = None,
  423. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  424. if output_attentions:
  425. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
  426. logger.warning_once(
  427. "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
  428. 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  429. )
  430. return super().forward(
  431. hidden_states=hidden_states,
  432. attention_mask=attention_mask,
  433. position_ids=position_ids,
  434. past_key_value=past_key_value,
  435. output_attentions=output_attentions,
  436. use_cache=use_cache,
  437. )
  438. bsz, q_len, _ = hidden_states.size()
  439. query_states = self.q_proj(hidden_states)
  440. key_states = self.k_proj(hidden_states)
  441. value_states = self.v_proj(hidden_states)
  442. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  443. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  444. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  445. kv_seq_len = key_states.shape[-2]
  446. if past_key_value is not None:
  447. kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
  448. cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
  449. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  450. if past_key_value is not None:
  451. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
  452. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  453. key_states = repeat_kv(key_states, self.num_key_value_groups)
  454. value_states = repeat_kv(value_states, self.num_key_value_groups)
  455. causal_mask = attention_mask
  456. if attention_mask is not None: # no matter the length, we just slice it
  457. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  458. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  459. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  460. if query_states.device.type == "cuda" and attention_mask is not None:
  461. query_states = query_states.contiguous()
  462. key_states = key_states.contiguous()
  463. value_states = value_states.contiguous()
  464. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  465. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  466. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
  467. is_causal = True if causal_mask is None and q_len > 1 else False
  468. attn_output = torch.nn.functional.scaled_dot_product_attention(
  469. query_states,
  470. key_states,
  471. value_states,
  472. attn_mask=causal_mask,
  473. dropout_p=self.attention_dropout if self.training else 0.0,
  474. is_causal=is_causal,
  475. )
  476. attn_output = attn_output.transpose(1, 2).contiguous()
  477. attn_output = attn_output.view(bsz, q_len, self.hidden_size)
  478. attn_output = self.o_proj(attn_output)
  479. return attn_output, None, past_key_value
  480. MIXTRAL_ATTENTION_CLASSES = {
  481. "eager": MixtralAttention,
  482. "flash_attention_2": MixtralFlashAttention2,
  483. "sdpa": MixtralSdpaAttention,
  484. }
  485. class MixtralBlockSparseTop2MLP(nn.Module):
  486. def __init__(self, config: MixtralConfig):
  487. super().__init__()
  488. self.ffn_dim = config.intermediate_size
  489. self.hidden_dim = config.hidden_size
  490. self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
  491. self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
  492. self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
  493. self.act_fn = ACT2FN[config.hidden_act]
  494. def forward(self, hidden_states):
  495. current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
  496. current_hidden_states = self.w2(current_hidden_states)
  497. return current_hidden_states
  498. class MixtralSparseMoeBlock(nn.Module):
  499. """
  500. This implementation is
  501. strictly equivalent to standard MoE with full capacity (no
  502. dropped tokens). It's faster since it formulates MoE operations
  503. in terms of block-sparse operations to accomodate imbalanced
  504. assignments of tokens to experts, whereas standard MoE either
  505. (1) drop tokens at the cost of reduced performance or (2) set
  506. capacity factor to number of experts and thus waste computation
  507. and memory on padding.
  508. """
  509. def __init__(self, config):
  510. super().__init__()
  511. self.hidden_dim = config.hidden_size
  512. self.ffn_dim = config.intermediate_size
  513. self.num_experts = config.num_local_experts
  514. self.top_k = config.num_experts_per_tok
  515. # gating
  516. self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
  517. self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
  518. # Jitter parameters
  519. self.jitter_noise = config.router_jitter_noise
  520. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  521. """ """
  522. batch_size, sequence_length, hidden_dim = hidden_states.shape
  523. if self.training and self.jitter_noise > 0:
  524. hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
  525. hidden_states = hidden_states.view(-1, hidden_dim)
  526. # router_logits: (batch * sequence_length, n_experts)
  527. router_logits = self.gate(hidden_states)
  528. routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
  529. routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
  530. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  531. # we cast back to the input dtype
  532. routing_weights = routing_weights.to(hidden_states.dtype)
  533. final_hidden_states = torch.zeros(
  534. (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
  535. )
  536. # One hot encode the selected experts to create an expert mask
  537. # this will be used to easily index which expert is going to be sollicitated
  538. expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
  539. # Loop over all available experts in the model and perform the computation on each expert
  540. for expert_idx in range(self.num_experts):
  541. expert_layer = self.experts[expert_idx]
  542. idx, top_x = torch.where(expert_mask[expert_idx])
  543. # Index the correct hidden states and compute the expert hidden state for
  544. # the current expert. We need to make sure to multiply the output hidden
  545. # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
  546. current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
  547. current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
  548. # However `index_add_` only support torch tensors for indexing so we'll use
  549. # the `top_x` tensor here.
  550. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
  551. final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
  552. return final_hidden_states, router_logits
  553. class MixtralDecoderLayer(nn.Module):
  554. def __init__(self, config: MixtralConfig, layer_idx: int):
  555. super().__init__()
  556. self.hidden_size = config.hidden_size
  557. self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
  558. self.block_sparse_moe = MixtralSparseMoeBlock(config)
  559. self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  560. self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  561. def forward(
  562. self,
  563. hidden_states: torch.Tensor,
  564. attention_mask: Optional[torch.Tensor] = None,
  565. position_ids: Optional[torch.LongTensor] = None,
  566. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  567. output_attentions: Optional[bool] = False,
  568. output_router_logits: Optional[bool] = False,
  569. use_cache: Optional[bool] = False,
  570. cache_position: Optional[torch.LongTensor] = None,
  571. **kwargs,
  572. ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
  573. """
  574. Args:
  575. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  576. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  577. `(batch, sequence_length)` where padding elements are indicated by 0.
  578. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
  579. output_attentions (`bool`, *optional*):
  580. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  581. returned tensors for more detail.
  582. output_router_logits (`bool`, *optional*):
  583. Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
  584. should not be returned during inference.
  585. use_cache (`bool`, *optional*):
  586. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  587. (see `past_key_values`).
  588. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  589. Indices depicting the position of the input sequence tokens in the sequence.
  590. kwargs (`dict`, *optional*):
  591. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  592. into the model
  593. """
  594. residual = hidden_states
  595. hidden_states = self.input_layernorm(hidden_states)
  596. # Self Attention
  597. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  598. hidden_states=hidden_states,
  599. attention_mask=attention_mask,
  600. position_ids=position_ids,
  601. past_key_value=past_key_value,
  602. output_attentions=output_attentions,
  603. use_cache=use_cache,
  604. cache_position=cache_position,
  605. )
  606. hidden_states = residual + hidden_states
  607. # Fully Connected
  608. residual = hidden_states
  609. hidden_states = self.post_attention_layernorm(hidden_states)
  610. hidden_states, router_logits = self.block_sparse_moe(hidden_states)
  611. hidden_states = residual + hidden_states
  612. outputs = (hidden_states,)
  613. if output_attentions:
  614. outputs += (self_attn_weights,)
  615. if use_cache:
  616. outputs += (present_key_value,)
  617. if output_router_logits:
  618. outputs += (router_logits,)
  619. return outputs
  620. MIXTRAL_START_DOCSTRING = r"""
  621. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  622. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  623. etc.)
  624. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  625. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  626. and behavior.
  627. Parameters:
  628. config ([`MixtralConfig`]):
  629. Model configuration class with all the parameters of the model. Initializing with a config file does not
  630. load the weights associated with the model, only the configuration. Check out the
  631. [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  632. """
  633. @add_start_docstrings(
  634. "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
  635. MIXTRAL_START_DOCSTRING,
  636. )
  637. # copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Mixtral
  638. # TODO (Raushan): bring back copied after compile compatibility
  639. class MixtralPreTrainedModel(PreTrainedModel):
  640. config_class = MixtralConfig
  641. base_model_prefix = "model"
  642. supports_gradient_checkpointing = True
  643. _no_split_modules = ["MixtralDecoderLayer"]
  644. _skip_keys_device_placement = "past_key_values"
  645. _supports_flash_attn_2 = True
  646. _supports_sdpa = True
  647. _supports_cache_class = True
  648. def _init_weights(self, module):
  649. std = self.config.initializer_range
  650. if isinstance(module, nn.Linear):
  651. module.weight.data.normal_(mean=0.0, std=std)
  652. if module.bias is not None:
  653. module.bias.data.zero_()
  654. elif isinstance(module, nn.Embedding):
  655. module.weight.data.normal_(mean=0.0, std=std)
  656. if module.padding_idx is not None:
  657. module.weight.data[module.padding_idx].zero_()
  658. MIXTRAL_INPUTS_DOCSTRING = r"""
  659. Args:
  660. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  661. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  662. it.
  663. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  664. [`PreTrainedTokenizer.__call__`] for details.
  665. [What are input IDs?](../glossary#input-ids)
  666. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  667. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  668. - 1 for tokens that are **not masked**,
  669. - 0 for tokens that are **masked**.
  670. [What are attention masks?](../glossary#attention-mask)
  671. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  672. [`PreTrainedTokenizer.__call__`] for details.
  673. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  674. `past_key_values`).
  675. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
  676. and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
  677. information on the default strategy.
  678. - 1 indicates the head is **not masked**,
  679. - 0 indicates the head is **masked**.
  680. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  681. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  682. config.n_positions - 1]`.
  683. [What are position IDs?](../glossary#position-ids)
  684. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  685. Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
  686. `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
  687. `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
  688. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  689. blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  690. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
  691. don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
  692. `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  693. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  694. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  695. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  696. model's internal embedding lookup matrix.
  697. use_cache (`bool`, *optional*):
  698. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  699. `past_key_values`).
  700. output_attentions (`bool`, *optional*):
  701. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  702. tensors for more detail.
  703. output_hidden_states (`bool`, *optional*):
  704. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  705. more detail.
  706. output_router_logits (`bool`, *optional*):
  707. Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
  708. should not be returned during inference.
  709. return_dict (`bool`, *optional*):
  710. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  711. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  712. Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
  713. this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
  714. the complete sequence length.
  715. """
  716. @add_start_docstrings(
  717. "The bare Mixtral Model outputting raw hidden-states without any specific head on top.",
  718. MIXTRAL_START_DOCSTRING,
  719. )
  720. # copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral
  721. # TODO @longjie no longer copied from Mistral after static cache
  722. class MixtralModel(MixtralPreTrainedModel):
  723. """
  724. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`]
  725. Args:
  726. config: MixtralConfig
  727. """
  728. def __init__(self, config: MixtralConfig):
  729. super().__init__(config)
  730. self.padding_idx = config.pad_token_id
  731. self.vocab_size = config.vocab_size
  732. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  733. self.layers = nn.ModuleList(
  734. [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  735. )
  736. self._attn_implementation = config._attn_implementation
  737. self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  738. self.gradient_checkpointing = False
  739. # Initialize weights and apply final processing
  740. self.post_init()
  741. def get_input_embeddings(self):
  742. return self.embed_tokens
  743. def set_input_embeddings(self, value):
  744. self.embed_tokens = value
  745. # Ignore copy
  746. @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
  747. def forward(
  748. self,
  749. input_ids: torch.LongTensor = None,
  750. attention_mask: Optional[torch.Tensor] = None,
  751. position_ids: Optional[torch.LongTensor] = None,
  752. past_key_values: Optional[List[torch.FloatTensor]] = None,
  753. inputs_embeds: Optional[torch.FloatTensor] = None,
  754. use_cache: Optional[bool] = None,
  755. output_attentions: Optional[bool] = None,
  756. output_hidden_states: Optional[bool] = None,
  757. output_router_logits: Optional[bool] = None,
  758. return_dict: Optional[bool] = None,
  759. cache_position: Optional[torch.LongTensor] = None,
  760. ) -> Union[Tuple, MoeModelOutputWithPast]:
  761. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  762. output_router_logits = (
  763. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  764. )
  765. output_hidden_states = (
  766. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  767. )
  768. use_cache = use_cache if use_cache is not None else self.config.use_cache
  769. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  770. if (input_ids is None) ^ (inputs_embeds is not None):
  771. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  772. if self.gradient_checkpointing and self.training:
  773. if use_cache:
  774. logger.warning_once(
  775. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  776. )
  777. use_cache = False
  778. # kept for BC (non `Cache` `past_key_values` inputs)
  779. return_legacy_cache = False
  780. if use_cache and not isinstance(past_key_values, Cache):
  781. return_legacy_cache = True
  782. if past_key_values is None:
  783. past_key_values = DynamicCache()
  784. else:
  785. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  786. logger.warning_once(
  787. "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
  788. "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
  789. "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
  790. )
  791. if inputs_embeds is None:
  792. inputs_embeds = self.embed_tokens(input_ids)
  793. if cache_position is None:
  794. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  795. cache_position = torch.arange(
  796. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  797. )
  798. if position_ids is None:
  799. position_ids = cache_position.unsqueeze(0)
  800. causal_mask = self._update_causal_mask(
  801. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  802. )
  803. hidden_states = inputs_embeds
  804. # decoder layers
  805. all_hidden_states = () if output_hidden_states else None
  806. all_self_attns = () if output_attentions else None
  807. all_router_logits = () if output_router_logits else None
  808. next_decoder_cache = None
  809. for decoder_layer in self.layers:
  810. if output_hidden_states:
  811. all_hidden_states += (hidden_states,)
  812. if self.gradient_checkpointing and self.training:
  813. layer_outputs = self._gradient_checkpointing_func(
  814. decoder_layer.__call__,
  815. hidden_states,
  816. causal_mask,
  817. position_ids,
  818. past_key_values,
  819. output_attentions,
  820. output_router_logits,
  821. use_cache,
  822. cache_position,
  823. )
  824. else:
  825. layer_outputs = decoder_layer(
  826. hidden_states,
  827. attention_mask=causal_mask,
  828. position_ids=position_ids,
  829. past_key_value=past_key_values,
  830. output_attentions=output_attentions,
  831. output_router_logits=output_router_logits,
  832. use_cache=use_cache,
  833. cache_position=cache_position,
  834. )
  835. hidden_states = layer_outputs[0]
  836. if use_cache:
  837. next_decoder_cache = layer_outputs[2 if output_attentions else 1]
  838. if output_attentions:
  839. all_self_attns += (layer_outputs[1],)
  840. if output_router_logits:
  841. all_router_logits += (layer_outputs[-1],)
  842. hidden_states = self.norm(hidden_states)
  843. # add hidden states from the last decoder layer
  844. if output_hidden_states:
  845. all_hidden_states += (hidden_states,)
  846. next_cache = next_decoder_cache if use_cache else None
  847. if return_legacy_cache:
  848. next_cache = next_cache.to_legacy_cache()
  849. if not return_dict:
  850. return tuple(
  851. v
  852. for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
  853. if v is not None
  854. )
  855. return MoeModelOutputWithPast(
  856. last_hidden_state=hidden_states,
  857. past_key_values=next_cache,
  858. hidden_states=all_hidden_states,
  859. attentions=all_self_attns,
  860. router_logits=all_router_logits,
  861. )
  862. # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
  863. def _update_causal_mask(
  864. self,
  865. attention_mask: torch.Tensor,
  866. input_tensor: torch.Tensor,
  867. cache_position: torch.Tensor,
  868. past_key_values: Cache,
  869. output_attentions: bool,
  870. ):
  871. if self.config._attn_implementation == "flash_attention_2":
  872. if attention_mask is not None and 0.0 in attention_mask:
  873. return attention_mask
  874. return None
  875. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  876. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  877. # to infer the attention mask.
  878. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  879. using_static_cache = isinstance(past_key_values, StaticCache)
  880. using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
  881. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  882. if (
  883. self.config._attn_implementation == "sdpa"
  884. and not (using_static_cache or using_sliding_window_cache)
  885. and not output_attentions
  886. ):
  887. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  888. attention_mask,
  889. inputs_embeds=input_tensor,
  890. past_key_values_length=past_seen_tokens,
  891. sliding_window=self.config.sliding_window,
  892. is_training=self.training,
  893. ):
  894. return None
  895. dtype, device = input_tensor.dtype, input_tensor.device
  896. min_dtype = torch.finfo(dtype).min
  897. sequence_length = input_tensor.shape[1]
  898. # SlidingWindowCache or StaticCache
  899. if using_sliding_window_cache or using_static_cache:
  900. target_length = past_key_values.get_max_cache_shape()
  901. # DynamicCache or no cache
  902. else:
  903. target_length = (
  904. attention_mask.shape[-1]
  905. if isinstance(attention_mask, torch.Tensor)
  906. else past_seen_tokens + sequence_length + 1
  907. )
  908. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  909. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  910. attention_mask,
  911. sequence_length=sequence_length,
  912. target_length=target_length,
  913. dtype=dtype,
  914. device=device,
  915. cache_position=cache_position,
  916. batch_size=input_tensor.shape[0],
  917. config=self.config,
  918. past_key_values=past_key_values,
  919. )
  920. if (
  921. self.config._attn_implementation == "sdpa"
  922. and attention_mask is not None
  923. and attention_mask.device.type == "cuda"
  924. and not output_attentions
  925. ):
  926. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  927. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  928. # Details: https://github.com/pytorch/pytorch/issues/110213
  929. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  930. return causal_mask
  931. @staticmethod
  932. # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Mixtral
  933. def _prepare_4d_causal_attention_mask_with_cache_position(
  934. attention_mask: torch.Tensor,
  935. sequence_length: int,
  936. target_length: int,
  937. dtype: torch.dtype,
  938. device: torch.device,
  939. cache_position: torch.Tensor,
  940. batch_size: int,
  941. config: MixtralConfig,
  942. past_key_values: Cache,
  943. ):
  944. """
  945. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  946. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  947. Args:
  948. attention_mask (`torch.Tensor`):
  949. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
  950. sequence_length (`int`):
  951. The sequence length being processed.
  952. target_length (`int`):
  953. The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
  954. dtype (`torch.dtype`):
  955. The dtype to use for the 4D attention mask.
  956. device (`torch.device`):
  957. The device to plcae the 4D attention mask on.
  958. cache_position (`torch.Tensor`):
  959. Indices depicting the position of the input sequence tokens in the sequence.
  960. batch_size (`torch.Tensor`):
  961. Batch size.
  962. config (`MixtralConfig`):
  963. The model's configuration class
  964. past_key_values (`Cache`):
  965. The cache class that is being used currently to generate
  966. """
  967. if attention_mask is not None and attention_mask.dim() == 4:
  968. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  969. causal_mask = attention_mask
  970. else:
  971. min_dtype = torch.finfo(dtype).min
  972. causal_mask = torch.full(
  973. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
  974. )
  975. diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
  976. if config.sliding_window is not None:
  977. # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
  978. # the check is needed to verify is current checkpoint was trained with sliding window or not
  979. if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
  980. sliding_attend_mask = torch.arange(target_length, device=device) <= (
  981. cache_position.reshape(-1, 1) - config.sliding_window
  982. )
  983. diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
  984. causal_mask *= diagonal_attend_mask
  985. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  986. if attention_mask is not None:
  987. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  988. if attention_mask.shape[-1] > target_length:
  989. attention_mask = attention_mask[:, :target_length]
  990. mask_length = attention_mask.shape[-1]
  991. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
  992. padding_mask = padding_mask == 0
  993. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  994. padding_mask, min_dtype
  995. )
  996. return causal_mask
  997. class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin):
  998. _tied_weights_keys = ["lm_head.weight"]
  999. def __init__(self, config):
  1000. super().__init__(config)
  1001. self.model = MixtralModel(config)
  1002. self.vocab_size = config.vocab_size
  1003. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1004. self.router_aux_loss_coef = config.router_aux_loss_coef
  1005. self.num_experts = config.num_local_experts
  1006. self.num_experts_per_tok = config.num_experts_per_tok
  1007. # Initialize weights and apply final processing
  1008. self.post_init()
  1009. def get_input_embeddings(self):
  1010. return self.model.embed_tokens
  1011. def set_input_embeddings(self, value):
  1012. self.model.embed_tokens = value
  1013. def get_output_embeddings(self):
  1014. return self.lm_head
  1015. def set_output_embeddings(self, new_embeddings):
  1016. self.lm_head = new_embeddings
  1017. def set_decoder(self, decoder):
  1018. self.model = decoder
  1019. def get_decoder(self):
  1020. return self.model
  1021. @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
  1022. @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  1023. # Ignore copy
  1024. def forward(
  1025. self,
  1026. input_ids: torch.LongTensor = None,
  1027. attention_mask: Optional[torch.Tensor] = None,
  1028. position_ids: Optional[torch.LongTensor] = None,
  1029. past_key_values: Optional[List[torch.FloatTensor]] = None,
  1030. inputs_embeds: Optional[torch.FloatTensor] = None,
  1031. labels: Optional[torch.LongTensor] = None,
  1032. use_cache: Optional[bool] = None,
  1033. output_attentions: Optional[bool] = None,
  1034. output_hidden_states: Optional[bool] = None,
  1035. output_router_logits: Optional[bool] = None,
  1036. return_dict: Optional[bool] = None,
  1037. cache_position: Optional[torch.LongTensor] = None,
  1038. num_logits_to_keep: int = 0,
  1039. **loss_kwargs,
  1040. ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
  1041. r"""
  1042. Args:
  1043. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1044. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1045. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1046. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1047. num_logits_to_keep (`int`, *optional*):
  1048. Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
  1049. `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  1050. token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
  1051. Returns:
  1052. Example:
  1053. ```python
  1054. >>> from transformers import AutoTokenizer, MixtralForCausalLM
  1055. >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
  1056. >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
  1057. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  1058. >>> inputs = tokenizer(prompt, return_tensors="pt")
  1059. >>> # Generate
  1060. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  1061. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1062. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  1063. ```"""
  1064. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1065. output_router_logits = (
  1066. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  1067. )
  1068. output_hidden_states = (
  1069. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1070. )
  1071. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1072. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1073. outputs = self.model(
  1074. input_ids=input_ids,
  1075. attention_mask=attention_mask,
  1076. position_ids=position_ids,
  1077. past_key_values=past_key_values,
  1078. inputs_embeds=inputs_embeds,
  1079. use_cache=use_cache,
  1080. output_attentions=output_attentions,
  1081. output_hidden_states=output_hidden_states,
  1082. output_router_logits=output_router_logits,
  1083. return_dict=return_dict,
  1084. cache_position=cache_position,
  1085. )
  1086. hidden_states = outputs[0]
  1087. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1088. logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
  1089. loss = None
  1090. if labels is not None:
  1091. loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
  1092. aux_loss = None
  1093. if output_router_logits:
  1094. aux_loss = load_balancing_loss_func(
  1095. outputs.router_logits if return_dict else outputs[-1],
  1096. self.num_experts,
  1097. self.num_experts_per_tok,
  1098. attention_mask,
  1099. )
  1100. if labels is not None:
  1101. loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device
  1102. if not return_dict:
  1103. output = (logits,) + outputs[1:]
  1104. if output_router_logits:
  1105. output = (aux_loss,) + output
  1106. return (loss,) + output if loss is not None else output
  1107. return MoeCausalLMOutputWithPast(
  1108. loss=loss,
  1109. aux_loss=aux_loss,
  1110. logits=logits,
  1111. past_key_values=outputs.past_key_values,
  1112. hidden_states=outputs.hidden_states,
  1113. attentions=outputs.attentions,
  1114. router_logits=outputs.router_logits,
  1115. )
  1116. @add_start_docstrings(
  1117. """
  1118. The Mixtral Model transformer with a sequence classification head on top (linear layer).
  1119. [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  1120. (e.g. GPT-2) do.
  1121. Since it does classification on the last token, it requires to know the position of the last token. If a
  1122. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  1123. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  1124. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  1125. each row of the batch).
  1126. """,
  1127. MIXTRAL_START_DOCSTRING,
  1128. )
  1129. # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL
  1130. class MixtralForSequenceClassification(MixtralPreTrainedModel):
  1131. def __init__(self, config):
  1132. super().__init__(config)
  1133. self.num_labels = config.num_labels
  1134. self.model = MixtralModel(config)
  1135. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  1136. # Initialize weights and apply final processing
  1137. self.post_init()
  1138. def get_input_embeddings(self):
  1139. return self.model.embed_tokens
  1140. def set_input_embeddings(self, value):
  1141. self.model.embed_tokens = value
  1142. @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
  1143. def forward(
  1144. self,
  1145. input_ids: Optional[torch.LongTensor] = None,
  1146. attention_mask: Optional[torch.Tensor] = None,
  1147. position_ids: Optional[torch.LongTensor] = None,
  1148. past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  1149. inputs_embeds: Optional[torch.FloatTensor] = None,
  1150. labels: Optional[torch.LongTensor] = None,
  1151. use_cache: Optional[bool] = None,
  1152. output_attentions: Optional[bool] = None,
  1153. output_hidden_states: Optional[bool] = None,
  1154. return_dict: Optional[bool] = None,
  1155. ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
  1156. r"""
  1157. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1158. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1159. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1160. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1161. """
  1162. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1163. transformer_outputs = self.model(
  1164. input_ids,
  1165. attention_mask=attention_mask,
  1166. position_ids=position_ids,
  1167. past_key_values=past_key_values,
  1168. inputs_embeds=inputs_embeds,
  1169. use_cache=use_cache,
  1170. output_attentions=output_attentions,
  1171. output_hidden_states=output_hidden_states,
  1172. return_dict=return_dict,
  1173. )
  1174. hidden_states = transformer_outputs[0]
  1175. logits = self.score(hidden_states)
  1176. if input_ids is not None:
  1177. batch_size = input_ids.shape[0]
  1178. else:
  1179. batch_size = inputs_embeds.shape[0]
  1180. if self.config.pad_token_id is None and batch_size != 1:
  1181. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  1182. if self.config.pad_token_id is None:
  1183. sequence_lengths = -1
  1184. else:
  1185. if input_ids is not None:
  1186. # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
  1187. sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
  1188. sequence_lengths = sequence_lengths % input_ids.shape[-1]
  1189. sequence_lengths = sequence_lengths.to(logits.device)
  1190. else:
  1191. sequence_lengths = -1
  1192. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
  1193. loss = None
  1194. if labels is not None:
  1195. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  1196. if not return_dict:
  1197. output = (pooled_logits,) + transformer_outputs[1:]
  1198. return ((loss,) + output) if loss is not None else output
  1199. return SequenceClassifierOutputWithPast(
  1200. loss=loss,
  1201. logits=pooled_logits,
  1202. past_key_values=transformer_outputs.past_key_values,
  1203. hidden_states=transformer_outputs.hidden_states,
  1204. attentions=transformer_outputs.attentions,
  1205. )
  1206. @add_start_docstrings(
  1207. """
  1208. The Mixtral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
  1209. output) e.g. for Named-Entity-Recognition (NER) tasks.
  1210. """,
  1211. MIXTRAL_START_DOCSTRING,
  1212. )
  1213. # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Mixtral, LLAMA->MIXTRAL
  1214. class MixtralForTokenClassification(MixtralPreTrainedModel):
  1215. def __init__(self, config):
  1216. super().__init__(config)
  1217. self.num_labels = config.num_labels
  1218. self.model = MixtralModel(config)
  1219. if getattr(config, "classifier_dropout", None) is not None:
  1220. classifier_dropout = config.classifier_dropout
  1221. elif getattr(config, "hidden_dropout", None) is not None:
  1222. classifier_dropout = config.hidden_dropout
  1223. else:
  1224. classifier_dropout = 0.1
  1225. self.dropout = nn.Dropout(classifier_dropout)
  1226. self.score = nn.Linear(config.hidden_size, config.num_labels)
  1227. # Initialize weights and apply final processing
  1228. self.post_init()
  1229. def get_input_embeddings(self):
  1230. return self.model.embed_tokens
  1231. def set_input_embeddings(self, value):
  1232. self.model.embed_tokens = value
  1233. @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
  1234. @add_code_sample_docstrings(
  1235. checkpoint=_CHECKPOINT_FOR_DOC,
  1236. output_type=TokenClassifierOutput,
  1237. config_class=_CONFIG_FOR_DOC,
  1238. )
  1239. def forward(
  1240. self,
  1241. input_ids: Optional[torch.LongTensor] = None,
  1242. attention_mask: Optional[torch.Tensor] = None,
  1243. position_ids: Optional[torch.LongTensor] = None,
  1244. past_key_values: Optional[List[torch.FloatTensor]] = None,
  1245. inputs_embeds: Optional[torch.FloatTensor] = None,
  1246. labels: Optional[torch.LongTensor] = None,
  1247. use_cache: Optional[bool] = None,
  1248. output_attentions: Optional[bool] = None,
  1249. output_hidden_states: Optional[bool] = None,
  1250. return_dict: Optional[bool] = None,
  1251. ) -> Union[Tuple, TokenClassifierOutput]:
  1252. r"""
  1253. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1254. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1255. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1256. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1257. """
  1258. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1259. outputs = self.model(
  1260. input_ids,
  1261. attention_mask=attention_mask,
  1262. position_ids=position_ids,
  1263. past_key_values=past_key_values,
  1264. inputs_embeds=inputs_embeds,
  1265. use_cache=use_cache,
  1266. output_attentions=output_attentions,
  1267. output_hidden_states=output_hidden_states,
  1268. return_dict=return_dict,
  1269. )
  1270. sequence_output = outputs[0]
  1271. sequence_output = self.dropout(sequence_output)
  1272. logits = self.score(sequence_output)
  1273. loss = None
  1274. if labels is not None:
  1275. loss = self.loss_function(logits, labels, self.config)
  1276. if not return_dict:
  1277. output = (logits,) + outputs[2:]
  1278. return ((loss,) + output) if loss is not None else output
  1279. return TokenClassifierOutput(
  1280. loss=loss,
  1281. logits=logits,
  1282. hidden_states=outputs.hidden_states,
  1283. attentions=outputs.attentions,
  1284. )
  1285. @add_start_docstrings(
  1286. """
  1287. The Mixtral Model transformer with a span classification head on top for extractive question-answering tasks like
  1288. SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
  1289. """,
  1290. MIXTRAL_START_DOCSTRING,
  1291. )
  1292. # Copied from transformers.models.mistral.modeling_mistral.MistralForQuestionAnswering with Mistral->Mixtral, MISTRAL->MIXTRAL
  1293. class MixtralForQuestionAnswering(MixtralPreTrainedModel):
  1294. base_model_prefix = "model"
  1295. # Copied from models.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Mixtral
  1296. def __init__(self, config):
  1297. super().__init__(config)
  1298. self.model = MixtralModel(config)
  1299. self.qa_outputs = nn.Linear(config.hidden_size, 2)
  1300. # Initialize weights and apply final processing
  1301. self.post_init()
  1302. def get_input_embeddings(self):
  1303. return self.model.embed_tokens
  1304. def set_input_embeddings(self, value):
  1305. self.model.embed_tokens = value
  1306. @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
  1307. def forward(
  1308. self,
  1309. input_ids: Optional[torch.LongTensor] = None,
  1310. attention_mask: Optional[torch.FloatTensor] = None,
  1311. position_ids: Optional[torch.LongTensor] = None,
  1312. past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  1313. inputs_embeds: Optional[torch.FloatTensor] = None,
  1314. start_positions: Optional[torch.LongTensor] = None,
  1315. end_positions: Optional[torch.LongTensor] = None,
  1316. output_attentions: Optional[bool] = None,
  1317. output_hidden_states: Optional[bool] = None,
  1318. return_dict: Optional[bool] = None,
  1319. **kwargs,
  1320. ) -> Union[Tuple, QuestionAnsweringModelOutput]:
  1321. r"""
  1322. start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1323. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  1324. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  1325. are not taken into account for computing the loss.
  1326. end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1327. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  1328. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  1329. are not taken into account for computing the loss.
  1330. """
  1331. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1332. outputs = self.model(
  1333. input_ids,
  1334. attention_mask=attention_mask,
  1335. position_ids=position_ids,
  1336. past_key_values=past_key_values,
  1337. inputs_embeds=inputs_embeds,
  1338. output_attentions=output_attentions,
  1339. output_hidden_states=output_hidden_states,
  1340. return_dict=return_dict,
  1341. )
  1342. sequence_output = outputs[0]
  1343. logits = self.qa_outputs(sequence_output)
  1344. start_logits, end_logits = logits.split(1, dim=-1)
  1345. start_logits = start_logits.squeeze(-1).contiguous()
  1346. end_logits = end_logits.squeeze(-1).contiguous()
  1347. loss = None
  1348. if start_positions is not None and end_positions is not None:
  1349. loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
  1350. if not return_dict:
  1351. output = (start_logits, end_logits) + outputs[2:]
  1352. return ((loss,) + output) if loss is not None else output
  1353. return QuestionAnsweringModelOutput(
  1354. loss=loss,
  1355. start_logits=start_logits,
  1356. end_logits=end_logits,
  1357. hidden_states=outputs.hidden_states,
  1358. attentions=outputs.attentions,
  1359. )