modeling_dbrx.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378
  1. # coding=utf-8
  2. # Copyright 2024 Databricks Mosaic Research and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch DBRX model."""
  16. import math
  17. from typing import Any, Optional, Tuple, Union
  18. import torch
  19. import torch.utils.checkpoint
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, StaticCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_attn_mask_utils import AttentionMaskConverter
  25. from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
  26. from ...modeling_utils import PreTrainedModel
  27. from ...utils import (
  28. add_start_docstrings,
  29. add_start_docstrings_to_model_forward,
  30. is_flash_attn_2_available,
  31. is_flash_attn_greater_or_equal_2_10,
  32. logging,
  33. replace_return_docstrings,
  34. )
  35. from .configuration_dbrx import DbrxConfig
  36. if is_flash_attn_2_available():
  37. from ...modeling_flash_attention_utils import _flash_attention_forward
# Module-level logger obtained from the transformers logging utilities (imported above as `logging`).
logger = logging.get_logger(__name__)
# Name of the configuration class used for documentation; presumably consumed by the docstring
# decorators (e.g. `replace_return_docstrings`) further down the file — not visible in this chunk.
_CONFIG_FOR_DOC = "DbrxConfig"
  40. # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with Gemma->Dbrx
  41. class DbrxRotaryEmbedding(nn.Module):
  42. def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
  43. super().__init__()
  44. self.dim = dim
  45. self.max_position_embeddings = max_position_embeddings
  46. self.base = base
  47. inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
  48. self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
  49. @torch.no_grad()
  50. def forward(self, x, position_ids, seq_len=None):
  51. # x: [bs, num_attention_heads, seq_len, head_size]
  52. self.inv_freq.to(x.device)
  53. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  54. position_ids_expanded = position_ids[:, None, :].float()
  55. # Force float32 since bfloat16 loses precision on long contexts
  56. # See https://github.com/huggingface/transformers/pull/29285
  57. device_type = x.device.type
  58. device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
  59. with torch.autocast(device_type=device_type, enabled=False):
  60. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  61. emb = torch.cat((freqs, freqs), dim=-1)
  62. cos = emb.cos()
  63. sin = emb.sin()
  64. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  65. # Copied from transformers.models.llama.modeling_llama.rotate_half
  66. def rotate_half(x):
  67. """Rotates half the hidden dims of the input."""
  68. x1 = x[..., : x.shape[-1] // 2]
  69. x2 = x[..., x.shape[-1] // 2 :]
  70. return torch.cat((-x2, x1), dim=-1)
  71. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  72. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  73. """Applies Rotary Position Embedding to the query and key tensors.
  74. Args:
  75. q (`torch.Tensor`): The query tensor.
  76. k (`torch.Tensor`): The key tensor.
  77. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  78. sin (`torch.Tensor`): The sine part of the rotary embedding.
  79. position_ids (`torch.Tensor`, *optional*):
  80. Deprecated and unused.
  81. unsqueeze_dim (`int`, *optional*, defaults to 1):
  82. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  83. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  84. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  85. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  86. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  87. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  88. Returns:
  89. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  90. """
  91. cos = cos.unsqueeze(unsqueeze_dim)
  92. sin = sin.unsqueeze(unsqueeze_dim)
  93. q_embed = (q * cos) + (rotate_half(q) * sin)
  94. k_embed = (k * cos) + (rotate_half(k) * sin)
  95. return q_embed, k_embed
  96. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  97. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  98. """
  99. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  100. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  101. """
  102. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  103. if n_rep == 1:
  104. return hidden_states
  105. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  106. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  107. def load_balancing_loss_func(
  108. gate_logits: torch.Tensor,
  109. num_experts: int,
  110. top_k: int,
  111. attention_mask: Optional[torch.Tensor],
  112. ) -> torch.Tensor:
  113. r"""Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
  114. See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
  115. function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
  116. experts is too unbalanced.
  117. Args:
  118. gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]):
  119. Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of
  120. shape [batch_size X sequence_length, num_experts].
  121. num_experts (`int`):
  122. Number of experts.
  123. top_k (`int`):
  124. The number of experts each token is routed to.
  125. attention_mask (`torch.Tensor`, *optional*):
  126. The attention_mask used in forward function
  127. shape [batch_size X sequence_length] if not None.
  128. Returns:
  129. The auxiliary loss.
  130. """
  131. if gate_logits is None or not isinstance(gate_logits, tuple):
  132. return torch.tensor(0.0)
  133. if isinstance(gate_logits, tuple):
  134. compute_device = gate_logits[0].device
  135. concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)
  136. routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
  137. _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
  138. expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
  139. if attention_mask is None:
  140. # Compute the percentage of tokens routed to each experts
  141. tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
  142. # Compute the average probability of routing to these experts
  143. router_prob_per_expert = torch.mean(routing_weights, dim=0)
  144. else:
  145. batch_size, sequence_length = attention_mask.shape
  146. num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
  147. # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
  148. expert_attention_mask = (
  149. attention_mask[None, :, :, None, None]
  150. .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
  151. .reshape(-1, top_k, num_experts)
  152. .to(compute_device)
  153. )
  154. # Compute the percentage of tokens routed to each experts
  155. tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
  156. expert_attention_mask, dim=0
  157. )
  158. # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
  159. router_per_expert_attention_mask = (
  160. attention_mask[None, :, :, None]
  161. .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
  162. .reshape(-1, num_experts)
  163. .to(compute_device)
  164. )
  165. # Compute the average probability of routing to these experts
  166. router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
  167. router_per_expert_attention_mask, dim=0
  168. )
  169. overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
  170. return overall_loss * num_experts
  171. class DbrxAttention(nn.Module):
  172. """Multi-head self attention."""
  173. def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None):
  174. super().__init__()
  175. self.config = config
  176. self.hidden_size = config.d_model
  177. self.num_heads = config.n_heads
  178. self.head_dim = self.hidden_size // self.num_heads
  179. self.max_position_embeddings = config.max_seq_len
  180. self.block_idx = block_idx
  181. if block_idx is None:
  182. logger.warning_once(
  183. f"Instantiating {self.__class__.__name__} without passing a `block_idx` is not recommended and will "
  184. + "lead to errors during the forward call if caching is used. Please make sure to provide a `block_idx` "
  185. + "when creating this class."
  186. )
  187. attn_config = config.attn_config
  188. self.attn_pdrop = attn_config.attn_pdrop
  189. self.clip_qkv = attn_config.clip_qkv
  190. self.num_key_value_heads = attn_config.kv_n_heads
  191. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  192. self.rope_theta = attn_config.rope_theta
  193. self.is_causal = True
  194. self.Wqkv = nn.Linear(
  195. self.hidden_size, self.hidden_size + 2 * self.num_key_value_heads * self.head_dim, bias=False
  196. )
  197. self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
  198. self.rotary_emb = DbrxRotaryEmbedding(
  199. self.head_dim,
  200. max_position_embeddings=self.max_position_embeddings,
  201. base=self.rope_theta,
  202. )
  203. def forward(
  204. self,
  205. hidden_states: torch.Tensor,
  206. position_ids: torch.LongTensor,
  207. attention_mask: Optional[torch.Tensor] = None,
  208. past_key_value: Optional[Cache] = None,
  209. output_attentions: bool = False,
  210. use_cache: bool = False,
  211. cache_position: Optional[torch.LongTensor] = None,
  212. **kwargs: Any,
  213. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
  214. bsz, q_len, _ = hidden_states.size()
  215. qkv_states = self.Wqkv(hidden_states)
  216. min_val = -self.clip_qkv if self.clip_qkv is not None else None
  217. max_val = self.clip_qkv
  218. qkv_states = qkv_states.clamp(min=min_val, max=max_val)
  219. query_states, key_states, value_states = qkv_states.split(
  220. [
  221. self.hidden_size,
  222. self.num_key_value_heads * self.head_dim,
  223. self.num_key_value_heads * self.head_dim,
  224. ],
  225. dim=2,
  226. )
  227. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  228. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  229. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  230. cos, sin = self.rotary_emb(value_states, position_ids)
  231. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  232. if past_key_value is not None:
  233. # sin and cos are specific to RoPE models; position_ids needed for the static cache
  234. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  235. key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs)
  236. key_states = repeat_kv(key_states, self.num_key_value_groups)
  237. value_states = repeat_kv(value_states, self.num_key_value_groups)
  238. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  239. if attention_mask is not None: # no matter the length, we just slice it
  240. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  241. attn_weights = attn_weights + causal_mask
  242. # upcast attention to fp32
  243. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  244. attn_weights = nn.functional.dropout(attn_weights, p=self.attn_pdrop, training=self.training)
  245. attn_output = torch.matmul(attn_weights, value_states)
  246. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  247. raise ValueError(
  248. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  249. + f" {attn_output.size()}"
  250. )
  251. attn_output = attn_output.transpose(1, 2).contiguous()
  252. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
  253. attn_output = self.out_proj(attn_output)
  254. if not output_attentions:
  255. attn_weights = None
  256. return attn_output, attn_weights, past_key_value
  257. class DbrxFlashAttention2(DbrxAttention):
  258. """Dbrx flash attention module.
  259. This module inherits from `DbrxAttention` as the weights of the module stays
  260. untouched. The only required change would be on the forward pass where it
  261. calls the public API of flash attention.
  262. """
  263. # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
  264. def __init__(self, *args, **kwargs):
  265. super().__init__(*args, **kwargs)
  266. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  267. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  268. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  269. self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
  270. def forward(
  271. self,
  272. hidden_states: torch.Tensor,
  273. attention_mask: Optional[torch.LongTensor] = None,
  274. position_ids: Optional[torch.LongTensor] = None,
  275. past_key_value: Optional[Cache] = None,
  276. output_attentions: bool = False,
  277. use_cache: bool = False,
  278. cache_position: Optional[torch.LongTensor] = None,
  279. **kwargs: Any,
  280. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  281. if isinstance(past_key_value, StaticCache):
  282. raise ValueError(
  283. "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
  284. "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
  285. )
  286. logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.")
  287. output_attentions = False
  288. bsz, q_len, _ = hidden_states.size()
  289. qkv_states = self.Wqkv(hidden_states)
  290. if self.clip_qkv is not None:
  291. qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
  292. query_states, key_states, value_states = qkv_states.split(
  293. [
  294. self.hidden_size,
  295. self.num_key_value_heads * self.head_dim,
  296. self.num_key_value_heads * self.head_dim,
  297. ],
  298. dim=2,
  299. )
  300. # Flash attention requires the input to have the shape
  301. # batch_size x seq_length x head_dim x hidden_dim
  302. # therefore we just need to keep the original shape
  303. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  304. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  305. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  306. cos, sin = self.rotary_emb(value_states, position_ids)
  307. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  308. if past_key_value is not None:
  309. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  310. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  311. key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs)
  312. # TODO: These transpose are quite inefficient but Flash Attention requires the layout
  313. # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  314. # to be able to avoid many of these transpose/reshape/view.
  315. query_states = query_states.transpose(1, 2)
  316. key_states = key_states.transpose(1, 2)
  317. value_states = value_states.transpose(1, 2)
  318. dropout_rate = self.attn_pdrop if self.training else 0.0
  319. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  320. # therefore the input hidden states gets silently casted in float32. Hence, we need
  321. # cast them back in the correct dtype just to be sure everything works as expected.
  322. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  323. # in fp32. (LlamaRMSNorm handles it correctly)
  324. input_dtype = query_states.dtype
  325. if input_dtype == torch.float32:
  326. if torch.is_autocast_enabled():
  327. target_dtype = torch.get_autocast_gpu_dtype()
  328. # Handle the case where the model is quantized
  329. elif hasattr(self.config, "_pre_quantization_dtype"):
  330. target_dtype = self.config._pre_quantization_dtype
  331. else:
  332. target_dtype = query_states.dtype
  333. logger.warning_once(
  334. "The input hidden states seems to be silently casted in float32, this might be "
  335. + "related to the fact you have upcasted embedding or layer norm layers in "
  336. + f"float32. We will cast back the input in {target_dtype}."
  337. )
  338. query_states = query_states.to(target_dtype)
  339. key_states = key_states.to(target_dtype)
  340. value_states = value_states.to(target_dtype)
  341. attn_output = _flash_attention_forward(
  342. query_states,
  343. key_states,
  344. value_states,
  345. attention_mask,
  346. q_len,
  347. position_ids=position_ids,
  348. dropout=dropout_rate,
  349. is_causal=self.is_causal,
  350. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  351. )
  352. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
  353. attn_output = self.out_proj(attn_output)
  354. if not output_attentions:
  355. attn_weights = None
  356. return attn_output, attn_weights, past_key_value
  357. class DbrxSdpaAttention(DbrxAttention):
  358. """
  359. Dbrx attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  360. `DbrxAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  361. SDPA API.
  362. """
  363. def forward(
  364. self,
  365. hidden_states: torch.Tensor,
  366. attention_mask: Optional[torch.Tensor] = None,
  367. position_ids: Optional[torch.LongTensor] = None,
  368. past_key_value: Optional[Cache] = None,
  369. output_attentions: bool = False,
  370. use_cache: bool = False,
  371. cache_position: Optional[torch.LongTensor] = None,
  372. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  373. if output_attentions:
  374. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
  375. logger.warning_once(
  376. "DbrxModel is using DbrxSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
  377. 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  378. )
  379. return super().forward(
  380. hidden_states=hidden_states,
  381. attention_mask=attention_mask,
  382. position_ids=position_ids,
  383. past_key_value=past_key_value,
  384. output_attentions=output_attentions,
  385. use_cache=use_cache,
  386. cache_position=cache_position,
  387. )
  388. bsz, q_len, _ = hidden_states.size()
  389. qkv_states = self.Wqkv(hidden_states)
  390. if self.clip_qkv is not None:
  391. qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
  392. query_states, key_states, value_states = qkv_states.split(
  393. [
  394. self.hidden_size,
  395. self.num_key_value_heads * self.head_dim,
  396. self.num_key_value_heads * self.head_dim,
  397. ],
  398. dim=2,
  399. )
  400. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  401. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  402. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  403. cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
  404. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)
  405. if past_key_value is not None:
  406. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  407. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  408. key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs)
  409. key_states = repeat_kv(key_states, self.num_key_value_groups)
  410. value_states = repeat_kv(value_states, self.num_key_value_groups)
  411. causal_mask = attention_mask
  412. if attention_mask is not None:
  413. causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
  414. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  415. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  416. if query_states.device.type == "cuda" and causal_mask is not None:
  417. query_states = query_states.contiguous()
  418. key_states = key_states.contiguous()
  419. value_states = value_states.contiguous()
  420. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  421. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  422. is_causal = True if causal_mask is None and q_len > 1 else False
  423. attn_output = torch.nn.functional.scaled_dot_product_attention(
  424. query_states,
  425. key_states,
  426. value_states,
  427. attn_mask=causal_mask,
  428. dropout_p=self.attn_pdrop if self.training else 0.0,
  429. is_causal=is_causal,
  430. )
  431. attn_output = attn_output.transpose(1, 2).contiguous()
  432. attn_output = attn_output.view(bsz, q_len, -1)
  433. attn_output = self.out_proj(attn_output)
  434. return attn_output, None, past_key_value
  435. DBRX_ATTENTION_CLASSES = {
  436. "eager": DbrxAttention,
  437. "flash_attention_2": DbrxFlashAttention2,
  438. "sdpa": DbrxSdpaAttention,
  439. }
  440. class DbrxNormAttentionNorm(nn.Module):
  441. def __init__(self, config: DbrxConfig, block_idx: Optional[int] = None):
  442. super().__init__()
  443. self.block_idx = block_idx
  444. self.resid_pdrop = config.resid_pdrop
  445. self.norm_1 = nn.LayerNorm(config.d_model, bias=False)
  446. self.attn = DBRX_ATTENTION_CLASSES[config._attn_implementation](
  447. config=config,
  448. block_idx=block_idx,
  449. )
  450. self.norm_2 = nn.LayerNorm(config.d_model, bias=False)
  451. def forward(
  452. self,
  453. hidden_states: torch.Tensor,
  454. position_ids: torch.LongTensor,
  455. attention_mask: Optional[torch.Tensor] = None,
  456. past_key_value: Optional[Cache] = None,
  457. output_attentions: bool = False,
  458. use_cache: bool = False,
  459. cache_position: Optional[torch.LongTensor] = None,
  460. **kwargs: Any,
  461. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
  462. residual_states = hidden_states
  463. hidden_states = self.norm_1(hidden_states).to(hidden_states.dtype)
  464. hidden_states, attn_weights, past_key_value = self.attn(
  465. hidden_states=hidden_states,
  466. attention_mask=attention_mask,
  467. position_ids=position_ids,
  468. past_key_value=past_key_value,
  469. output_attentions=output_attentions,
  470. use_cache=use_cache,
  471. cache_position=cache_position,
  472. **kwargs,
  473. )
  474. hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training)
  475. hidden_states = hidden_states + residual_states
  476. residual_states = hidden_states
  477. hidden_states = self.norm_2(hidden_states).to(hidden_states.dtype)
  478. return residual_states, hidden_states, attn_weights, past_key_value
  479. class DbrxRouter(nn.Module):
  480. def __init__(
  481. self,
  482. hidden_size: int,
  483. moe_num_experts: int,
  484. moe_top_k: int,
  485. moe_jitter_eps: Optional[float],
  486. moe_normalize_expert_weights: Optional[float],
  487. ):
  488. super().__init__()
  489. self.hidden_size = hidden_size
  490. self.moe_num_experts = moe_num_experts
  491. self.moe_top_k = moe_top_k
  492. self.moe_jitter_eps = moe_jitter_eps
  493. self.moe_normalize_expert_weights = moe_normalize_expert_weights
  494. self.layer = nn.Linear(self.hidden_size, self.moe_num_experts, bias=False)
  495. def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
  496. if self.training and self.moe_jitter_eps is not None:
  497. hidden_states *= torch.empty_like(hidden_states).uniform_(
  498. 1.0 - self.moe_jitter_eps, 1.0 + self.moe_jitter_eps
  499. )
  500. hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
  501. weights = self.layer(hidden_states).softmax(dim=-1, dtype=torch.float32)
  502. top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
  503. top_weights_scale = (
  504. torch.norm(top_weights, p=self.moe_normalize_expert_weights, dim=-1, keepdim=True)
  505. if self.moe_normalize_expert_weights is not None
  506. else 1.0
  507. )
  508. top_weights = top_weights / top_weights_scale
  509. weights = weights.to(hidden_states.dtype)
  510. top_weights = top_weights.to(hidden_states.dtype)
  511. return weights, top_weights, top_experts
  512. class DbrxExpertGLU(nn.Module):
  513. def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
  514. super().__init__()
  515. self.hidden_size = hidden_size
  516. self.ffn_hidden_size = ffn_hidden_size
  517. self.moe_num_experts = moe_num_experts
  518. self.w1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
  519. self.v1 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
  520. self.w2 = nn.Parameter(torch.empty(moe_num_experts * ffn_hidden_size, hidden_size))
  521. act_fn_name = ffn_act_fn.get("name", "silu")
  522. self.activation_fn = ACT2FN[act_fn_name]
  523. def forward(
  524. self, x: torch.Tensor, expert_w1: torch.Tensor, expert_v1: torch.Tensor, expert_w2: torch.Tensor
  525. ) -> torch.Tensor:
  526. gate_proj = x.matmul(expert_w1.t())
  527. up_proj = x.matmul(expert_v1.t())
  528. gate_proj = self.activation_fn(gate_proj)
  529. intermediate_states = gate_proj * up_proj
  530. down_proj = intermediate_states.matmul(expert_w2)
  531. return down_proj
  532. class DbrxExperts(nn.Module):
  533. def __init__(self, hidden_size: int, ffn_hidden_size: int, moe_num_experts: int, ffn_act_fn: dict):
  534. super().__init__()
  535. self.moe_num_experts = moe_num_experts
  536. self.mlp = DbrxExpertGLU(
  537. hidden_size=hidden_size,
  538. ffn_hidden_size=ffn_hidden_size,
  539. moe_num_experts=moe_num_experts,
  540. ffn_act_fn=ffn_act_fn,
  541. )
  542. def forward(
  543. self, x: torch.Tensor, weights: torch.Tensor, top_weights: torch.Tensor, top_experts: torch.LongTensor
  544. ) -> torch.Tensor:
  545. bsz, q_len, hidden_size = x.shape
  546. x = x.view(-1, hidden_size)
  547. out = torch.zeros_like(x)
  548. expert_mask = nn.functional.one_hot(top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
  549. # Chunk experts at once to avoid storing full parameter multiple times in autograd
  550. w1_chunked = self.mlp.w1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
  551. self.moe_num_experts, dim=0
  552. )
  553. v1_chunked = self.mlp.v1.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
  554. self.moe_num_experts, dim=0
  555. )
  556. w2_chunked = self.mlp.w2.view(self.mlp.moe_num_experts, self.mlp.ffn_hidden_size, self.mlp.hidden_size).chunk(
  557. self.moe_num_experts, dim=0
  558. )
  559. w1_chunked = [w1.squeeze(dim=0) for w1 in w1_chunked]
  560. v1_chunked = [v1.squeeze(dim=0) for v1 in v1_chunked]
  561. w2_chunked = [w2.squeeze(dim=0) for w2 in w2_chunked]
  562. for expert_idx in range(0, self.moe_num_experts):
  563. topk_idx, token_idx = torch.where(expert_mask[expert_idx])
  564. if token_idx.shape[0] == 0:
  565. continue
  566. token_list = token_idx
  567. topk_list = topk_idx
  568. expert_tokens = x[None, token_list].reshape(-1, hidden_size)
  569. expert_out = (
  570. self.mlp(expert_tokens, w1_chunked[expert_idx], v1_chunked[expert_idx], w2_chunked[expert_idx])
  571. * top_weights[token_list, topk_list, None]
  572. )
  573. out.index_add_(0, token_idx, expert_out)
  574. out = out.reshape(bsz, q_len, hidden_size)
  575. return out
  576. class DbrxFFN(nn.Module):
  577. def __init__(self, config: DbrxConfig):
  578. super().__init__()
  579. ffn_config = config.ffn_config
  580. self.router = DbrxRouter(
  581. hidden_size=config.d_model,
  582. moe_num_experts=ffn_config.moe_num_experts,
  583. moe_top_k=ffn_config.moe_top_k,
  584. moe_jitter_eps=ffn_config.moe_jitter_eps,
  585. moe_normalize_expert_weights=ffn_config.moe_normalize_expert_weights,
  586. )
  587. self.experts = DbrxExperts(
  588. hidden_size=config.d_model,
  589. ffn_hidden_size=ffn_config.ffn_hidden_size,
  590. moe_num_experts=ffn_config.moe_num_experts,
  591. ffn_act_fn=ffn_config.ffn_act_fn,
  592. )
  593. def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  594. weights, top_weights, top_experts = self.router(x)
  595. out = self.experts(x, weights, top_weights, top_experts)
  596. return out, weights
  597. class DbrxBlock(nn.Module):
  598. def __init__(self, config: DbrxConfig, block_idx: int):
  599. super().__init__()
  600. self.hidden_size = config.d_model
  601. self.resid_pdrop = config.resid_pdrop
  602. self.block_idx = block_idx
  603. self.norm_attn_norm = DbrxNormAttentionNorm(
  604. config=config,
  605. block_idx=block_idx,
  606. )
  607. self.ffn = DbrxFFN(config=config)
  608. def forward(
  609. self,
  610. hidden_states: torch.Tensor,
  611. attention_mask: Optional[torch.Tensor] = None,
  612. position_ids: torch.LongTensor = None,
  613. past_key_value: Optional[Cache] = None,
  614. output_attentions: Optional[bool] = False,
  615. output_router_logits: Optional[bool] = False,
  616. use_cache: Optional[bool] = False,
  617. cache_position: Optional[torch.LongTensor] = None,
  618. **kwargs: Any,
  619. ) -> Union[
  620. Tuple[torch.Tensor],
  621. Tuple[torch.Tensor, Optional[torch.Tensor]],
  622. Tuple[torch.Tensor, Optional[Cache]],
  623. Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]],
  624. Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]],
  625. Tuple[torch.Tensor, Optional[Cache], Optional[torch.Tensor]],
  626. Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache], Optional[torch.Tensor]],
  627. ]:
  628. """Forward function for DbrxBlock.
  629. Args:
  630. hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  631. position_ids (`torch.LongTensor`): position ids of shape `(batch, seq_len)`
  632. attention_mask (`torch.Tensor`, *optional*): attention mask of size (batch_size, sequence_length)
  633. if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length)
  634. if default attention is used.
  635. past_key_value (`Tuple(torch.Tensor)`, *optional*): cached past key and value projection states
  636. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all
  637. attention layers. See `attentions` under returned tensors for more detail.
  638. output_router_logits (`bool`, *optional*): Whether or not to return the router logits.
  639. use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are
  640. returned and can be used to speed up decoding (see `past_key_values`).
  641. cache_position (`torch.LongTensor`, *optional*): position ids of the cache
  642. """
  643. # Norm + Attention + Norm
  644. resid_states, hidden_states, self_attn_weights, present_key_value = self.norm_attn_norm(
  645. hidden_states=hidden_states,
  646. attention_mask=attention_mask,
  647. position_ids=position_ids,
  648. past_key_value=past_key_value,
  649. output_attentions=output_attentions,
  650. use_cache=use_cache,
  651. cache_position=cache_position,
  652. **kwargs,
  653. )
  654. # Fully Connected
  655. hidden_states, router_logits = self.ffn(hidden_states)
  656. hidden_states = nn.functional.dropout(hidden_states, p=self.resid_pdrop, training=self.training)
  657. hidden_states = resid_states + hidden_states
  658. outputs = (hidden_states,)
  659. if output_attentions:
  660. outputs += (self_attn_weights,)
  661. if use_cache:
  662. outputs += (present_key_value,)
  663. if output_router_logits:
  664. outputs += (router_logits,)
  665. return outputs
# Class-level docstring shared by the DBRX model classes. It is passed to `add_start_docstrings` on
# `DbrxPreTrainedModel`, `DbrxModel` and `DbrxForCausalLM`. This is runtime text rendered into the generated
# API docs, so edits to it change user-facing documentation.
DBRX_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`DbrxConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
  679. @add_start_docstrings(
  680. "The bare DBRX Model outputting raw hidden-states without any specific head on top.",
  681. DBRX_START_DOCSTRING,
  682. )
  683. class DbrxPreTrainedModel(PreTrainedModel):
  684. config_class = DbrxConfig
  685. base_model_prefix = "transformer"
  686. supports_gradient_checkpointing = True
  687. _no_split_modules = ["DbrxBlock"]
  688. _skip_keys_device_placement = ["past_key_values"]
  689. _supports_flash_attn_2 = True
  690. _supports_sdpa = True
  691. _supports_cache_class = True
  692. _supports_quantized_cache = True
  693. _supports_static_cache = True
  694. def _init_weights(self, module: nn.Module):
  695. std = self.config.initializer_range
  696. if isinstance(module, nn.Linear):
  697. module.weight.data.normal_(mean=0.0, std=std)
  698. if module.bias is not None:
  699. module.bias.data.zero_()
  700. elif isinstance(module, nn.Embedding):
  701. module.weight.data.normal_(mean=0.0, std=std)
  702. if module.padding_idx is not None:
  703. module.weight.data[module.padding_idx].zero_()
  704. elif isinstance(module, nn.LayerNorm):
  705. module.weight.data.normal_(mean=0.0, std=std)
  706. if module.bias is not None:
  707. module.bias.data.zero_()
  708. elif isinstance(module, DbrxExpertGLU):
  709. module.w1.data.normal_(mean=0.0, std=std)
  710. module.v1.data.normal_(mean=0.0, std=std)
  711. module.w2.data.normal_(mean=0.0, std=std)
# Shared argument documentation injected into the `forward` docstrings of `DbrxModel` and `DbrxForCausalLM`
# via `add_start_docstrings_to_model_forward`. This is runtime text rendered into the generated API docs.
# NOTE(review): the `attention_mask` entry still carries text copied from other models that does not apply here
# (`decoder_input_ids`, `modeling_opt._prepare_decoder_attention_mask`, and the "head is masked" bullets), and
# `position_ids` refers to `config.n_positions`. Confirm the DBRX config field names before rewording.
DBRX_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
output_router_logits (`bool`, *optional*):
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
should not be returned during inference.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
@add_start_docstrings(
    "The bare DBRX Model outputting raw hidden-states without any specific head on top.",
    DBRX_START_DOCSTRING,
)
class DbrxModel(DbrxPreTrainedModel):
    """Transformer decoder made of `config.n_layers` [`DbrxBlock`] layers, followed by a final LayerNorm.

    Args:
        config ([`DbrxConfig`]): Model configuration class with all parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    def __init__(self, config: DbrxConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # Dropout probability applied to the token embeddings at the start of `forward`.
        self.emb_pdrop = config.emb_pdrop

        # Token embedding table of size (vocab_size, d_model).
        self.wte = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
        # One decoder block per layer; each block receives its own index.
        self.blocks = nn.ModuleList([DbrxBlock(config, block_idx) for block_idx in range(config.n_layers)])
        # Final bias-free LayerNorm, applied after the last block in `forward`.
        self.norm_f = nn.LayerNorm(config.d_model, bias=False)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.wte

    def set_input_embeddings(self, value: nn.Embedding):
        self.wte = value

    @add_start_docstrings_to_model_forward(DBRX_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        # Any flag left as None falls back to the corresponding config value.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # XOR: raise when both or neither of `input_ids` / `inputs_embeds` are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

        # kept for BC (non `Cache` `past_key_values` inputs)
        # A missing or legacy-tuple cache is wrapped in a DynamicCache here. `return_legacy_cache` records that
        # the result must be converted back with `to_legacy_cache()` before returning.
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        # Default cache positions continue right after the tokens already stored in the cache.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_router_logits = () if output_router_logits else None
        next_decoder_cache = None

        for block in self.blocks:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # The checkpointed call passes arguments positionally, in the same order as DbrxBlock.forward's
            # signature.
            if self.gradient_checkpointing and self.training:
                block_outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    output_router_logits,
                    use_cache,
                    cache_position,
                )
            else:
                block_outputs = block(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    output_router_logits=output_router_logits,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            # Each block returns (hidden_states, [attn_weights], [present_key_value], [router_logits]).
            # Optional entries exist only when their flag is set, which is why the cache index depends on
            # `output_attentions` and router logits are always taken from the end.
            hidden_states = block_outputs[0]

            if use_cache:
                next_decoder_cache = block_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (block_outputs[1],)

            if output_router_logits:
                all_router_logits += (block_outputs[-1],)

        hidden_states = self.norm_f(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        # Tuple output drops every entry that is None, so its length varies with the flags.
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]
                if v is not None
            )
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )

    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to plcae the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
  1032. @add_start_docstrings("The DBRX Model transformer for causal language modeling.", DBRX_START_DOCSTRING)
  1033. class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin):
  1034. def __init__(self, config: DbrxConfig):
  1035. super().__init__(config)
  1036. self.transformer = DbrxModel(config)
  1037. self.vocab_size = config.vocab_size
  1038. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1039. self.moe_loss_weight = config.ffn_config.moe_loss_weight
  1040. self.num_experts = config.ffn_config.moe_num_experts
  1041. self.num_experts_per_tok = config.ffn_config.moe_top_k
  1042. # Initialize weights and apply final processing
  1043. self.post_init()
  1044. def get_input_embeddings(self) -> nn.Embedding:
  1045. return self.transformer.get_input_embeddings()
  1046. def set_input_embeddings(self, value: nn.Embedding):
  1047. self.transformer.set_input_embeddings(value)
  1048. def get_output_embeddings(self) -> nn.Linear:
  1049. return self.lm_head
  1050. def set_output_embeddings(self, new_embeddings: nn.Linear):
  1051. self.lm_head = new_embeddings
  1052. def set_decoder(self, decoder: DbrxModel):
  1053. self.transformer = decoder
  1054. def get_decoder(self) -> DbrxModel:
  1055. return self.transformer
  1056. @add_start_docstrings_to_model_forward(DBRX_INPUTS_DOCSTRING)
  1057. @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  1058. def forward(
  1059. self,
  1060. input_ids: Optional[torch.LongTensor] = None,
  1061. attention_mask: Optional[torch.Tensor] = None,
  1062. position_ids: Optional[torch.LongTensor] = None,
  1063. past_key_values: Optional[Cache] = None,
  1064. inputs_embeds: Optional[torch.Tensor] = None,
  1065. labels: Optional[torch.LongTensor] = None,
  1066. use_cache: Optional[bool] = None,
  1067. output_attentions: Optional[bool] = None,
  1068. output_hidden_states: Optional[bool] = None,
  1069. output_router_logits: Optional[bool] = None,
  1070. return_dict: Optional[bool] = None,
  1071. cache_position: Optional[torch.LongTensor] = None,
  1072. num_logits_to_keep: int = 0,
  1073. ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
  1074. r"""Forward function for causal language modeling.
  1075. Args:
  1076. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1077. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1078. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1079. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1080. num_logits_to_keep (`int`, *optional*):
  1081. Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
  1082. `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  1083. token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
  1084. Returns:
  1085. Example:
  1086. ```python
  1087. >> from transformers import AutoTokenizer, DbrxForCausalLM
  1088. >> model = DbrxForCausalLM.from_pretrained("databricks/dbrx-instruct")
  1089. >> tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx-instruct")
  1090. >> prompt = "Hey, are you conscious? Can you talk to me?"
  1091. >> inputs = tokenizer(prompt, return_tensors="pt")
  1092. >> # Generate
  1093. >> generate_ids = model.generate(inputs.input_ids, max_length=30)
  1094. >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1095. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  1096. ```
  1097. """
  1098. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1099. output_hidden_states = (
  1100. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1101. )
  1102. output_router_logits = (
  1103. output_router_logits if output_router_logits is not None else self.config.output_router_logits
  1104. )
  1105. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1106. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1107. outputs = self.transformer(
  1108. input_ids=input_ids,
  1109. attention_mask=attention_mask,
  1110. position_ids=position_ids,
  1111. past_key_values=past_key_values,
  1112. inputs_embeds=inputs_embeds,
  1113. use_cache=use_cache,
  1114. output_attentions=output_attentions,
  1115. output_hidden_states=output_hidden_states,
  1116. output_router_logits=output_router_logits,
  1117. return_dict=return_dict,
  1118. cache_position=cache_position,
  1119. )
  1120. hidden_states = outputs[0]
  1121. # No upscaling to float was ever done for Dbrx
  1122. logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
  1123. loss = None
  1124. if labels is not None:
  1125. # Shift so that tokens < n predict n
  1126. shift_logits = logits[..., :-1, :].contiguous()
  1127. shift_labels = labels[..., 1:].contiguous()
  1128. # Flatten the tokens
  1129. loss_fct = nn.CrossEntropyLoss()
  1130. shift_logits = shift_logits.view(-1, self.config.vocab_size)
  1131. shift_labels = shift_labels.view(-1)
  1132. # Enable model parallelism
  1133. shift_labels = shift_labels.to(shift_logits.device)
  1134. loss = loss_fct(shift_logits, shift_labels)
  1135. aux_loss = None
  1136. if output_router_logits:
  1137. aux_loss = load_balancing_loss_func(
  1138. outputs.router_logits if return_dict else outputs[-1],
  1139. self.num_experts,
  1140. self.num_experts_per_tok,
  1141. attention_mask,
  1142. )
  1143. if labels is not None and loss is not None:
  1144. loss += self.moe_loss_weight * aux_loss.to(loss.device) # make sure to reside in the same device
  1145. if not return_dict:
  1146. output = (logits,) + outputs[1:]
  1147. if output_router_logits:
  1148. output = (aux_loss,) + output
  1149. return (loss,) + output if loss is not None else output
  1150. return MoeCausalLMOutputWithPast(
  1151. loss=loss,
  1152. aux_loss=aux_loss,
  1153. logits=logits,
  1154. past_key_values=outputs.past_key_values,
  1155. hidden_states=outputs.hidden_states,
  1156. attentions=outputs.attentions,
  1157. router_logits=outputs.router_logits,
  1158. )