modeling_starcoder2.py 64 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375
  1. # coding=utf-8
  2. # Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """PyTorch Starcoder2 model."""
  21. import math
  22. from typing import List, Optional, Tuple, Union
  23. import torch
  24. import torch.utils.checkpoint
  25. from torch import nn
  26. from torch.nn import CrossEntropyLoss
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
  29. from ...generation import GenerationMixin
  30. from ...modeling_attn_mask_utils import AttentionMaskConverter
  31. from ...modeling_outputs import (
  32. BaseModelOutputWithPast,
  33. CausalLMOutputWithPast,
  34. SequenceClassifierOutputWithPast,
  35. TokenClassifierOutput,
  36. )
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
  38. from ...modeling_utils import PreTrainedModel
  39. from ...utils import (
  40. add_code_sample_docstrings,
  41. add_start_docstrings,
  42. add_start_docstrings_to_model_forward,
  43. is_flash_attn_2_available,
  44. is_flash_attn_greater_or_equal_2_10,
  45. logging,
  46. replace_return_docstrings,
  47. )
  48. from .configuration_starcoder2 import Starcoder2Config
if is_flash_attn_2_available():
    # Imported only when flash-attn 2 is available; `_flash_attention_forward` is called by
    # `Starcoder2FlashAttention2.forward`.
    from ...modeling_flash_attention_utils import _flash_attention_forward

# Module-level logger (used for the `warning_once` deprecation/fallback messages in this file).
logger = logging.get_logger(__name__)

# Checkpoint / config names for generated documentation; presumably consumed by the docstring
# decorators imported above (e.g. `add_code_sample_docstrings`) later in the file — TODO confirm.
_CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b"
_CONFIG_FOR_DOC = "Starcoder2Config"
  54. # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Starcoder2
  55. class Starcoder2RotaryEmbedding(nn.Module):
  56. def __init__(
  57. self,
  58. dim=None,
  59. max_position_embeddings=2048,
  60. base=10000,
  61. device=None,
  62. scaling_factor=1.0,
  63. rope_type="default",
  64. config: Optional[Starcoder2Config] = None,
  65. ):
  66. super().__init__()
  67. # TODO (joao): remove the `if` below, only used for BC
  68. self.rope_kwargs = {}
  69. if config is None:
  70. logger.warning_once(
  71. "`Starcoder2RotaryEmbedding` can now be fully parameterized by passing the model config through the "
  72. "`config` argument. All other arguments will be removed in v4.46"
  73. )
  74. self.rope_kwargs = {
  75. "rope_type": rope_type,
  76. "factor": scaling_factor,
  77. "dim": dim,
  78. "base": base,
  79. "max_position_embeddings": max_position_embeddings,
  80. }
  81. self.rope_type = rope_type
  82. self.max_seq_len_cached = max_position_embeddings
  83. self.original_max_seq_len = max_position_embeddings
  84. else:
  85. # BC: "rope_type" was originally "type"
  86. if config.rope_scaling is not None:
  87. self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
  88. else:
  89. self.rope_type = "default"
  90. self.max_seq_len_cached = config.max_position_embeddings
  91. self.original_max_seq_len = config.max_position_embeddings
  92. self.config = config
  93. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  94. inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
  95. self.register_buffer("inv_freq", inv_freq, persistent=False)
  96. self.original_inv_freq = self.inv_freq
  97. def _dynamic_frequency_update(self, position_ids, device):
  98. """
  99. dynamic RoPE layers should recompute `inv_freq` in the following situations:
  100. 1 - growing beyond the cached sequence length (allow scaling)
  101. 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
  102. """
  103. seq_len = torch.max(position_ids) + 1
  104. if seq_len > self.max_seq_len_cached: # growth
  105. inv_freq, self.attention_scaling = self.rope_init_fn(
  106. self.config, device, seq_len=seq_len, **self.rope_kwargs
  107. )
  108. self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
  109. self.max_seq_len_cached = seq_len
  110. if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
  111. self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
  112. self.max_seq_len_cached = self.original_max_seq_len
  113. @torch.no_grad()
  114. def forward(self, x, position_ids):
  115. if "dynamic" in self.rope_type:
  116. self._dynamic_frequency_update(position_ids, device=x.device)
  117. # Core RoPE block
  118. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  119. position_ids_expanded = position_ids[:, None, :].float()
  120. # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
  121. device_type = x.device.type
  122. device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
  123. with torch.autocast(device_type=device_type, enabled=False):
  124. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  125. emb = torch.cat((freqs, freqs), dim=-1)
  126. cos = emb.cos()
  127. sin = emb.sin()
  128. # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
  129. cos = cos * self.attention_scaling
  130. sin = sin * self.attention_scaling
  131. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  132. # Copied from transformers.models.llama.modeling_llama.rotate_half
  133. def rotate_half(x):
  134. """Rotates half the hidden dims of the input."""
  135. x1 = x[..., : x.shape[-1] // 2]
  136. x2 = x[..., x.shape[-1] // 2 :]
  137. return torch.cat((-x2, x1), dim=-1)
  138. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  139. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  140. """Applies Rotary Position Embedding to the query and key tensors.
  141. Args:
  142. q (`torch.Tensor`): The query tensor.
  143. k (`torch.Tensor`): The key tensor.
  144. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  145. sin (`torch.Tensor`): The sine part of the rotary embedding.
  146. position_ids (`torch.Tensor`, *optional*):
  147. Deprecated and unused.
  148. unsqueeze_dim (`int`, *optional*, defaults to 1):
  149. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  150. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  151. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  152. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  153. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  154. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  155. Returns:
  156. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  157. """
  158. cos = cos.unsqueeze(unsqueeze_dim)
  159. sin = sin.unsqueeze(unsqueeze_dim)
  160. q_embed = (q * cos) + (rotate_half(q) * sin)
  161. k_embed = (k * cos) + (rotate_half(k) * sin)
  162. return q_embed, k_embed
  163. class Starcoder2MLP(nn.Module):
  164. def __init__(self, config: Starcoder2Config):
  165. super().__init__()
  166. embed_dim = config.hidden_size
  167. self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
  168. self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
  169. self.act = ACT2FN[config.hidden_act]
  170. self.residual_dropout = config.residual_dropout
  171. def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
  172. hidden_states = self.c_fc(hidden_states)
  173. hidden_states = self.act(hidden_states)
  174. hidden_states = self.c_proj(hidden_states)
  175. hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
  176. return hidden_states
  177. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  178. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  179. """
  180. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  181. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  182. """
  183. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  184. if n_rep == 1:
  185. return hidden_states
  186. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  187. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  188. class Starcoder2Attention(nn.Module):
  189. """
  190. Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
  191. and "Generating Long Sequences with Sparse Transformers".
  192. """
  193. def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
  194. super().__init__()
  195. self.config = config
  196. self.layer_idx = layer_idx
  197. if layer_idx is None:
  198. logger.warning_once(
  199. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  200. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  201. "when creating this class."
  202. )
  203. self.hidden_size = config.hidden_size
  204. self.num_heads = config.num_attention_heads
  205. self.head_dim = self.hidden_size // self.num_heads
  206. self.num_key_value_heads = config.num_key_value_heads
  207. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  208. self.rope_theta = config.rope_theta
  209. self.use_bias = config.use_bias
  210. self.is_causal = True
  211. self.attention_dropout = config.attention_dropout
  212. self.residual_dropout = config.residual_dropout
  213. if (self.head_dim * self.num_heads) != self.hidden_size:
  214. raise ValueError(
  215. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  216. f" and `num_heads`: {self.num_heads})."
  217. )
  218. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=self.use_bias)
  219. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias)
  220. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=self.use_bias)
  221. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.use_bias)
  222. self.rotary_emb = Starcoder2RotaryEmbedding(config=self.config)
  223. def forward(
  224. self,
  225. hidden_states: torch.Tensor,
  226. attention_mask: Optional[torch.Tensor] = None,
  227. position_ids: Optional[torch.LongTensor] = None,
  228. past_key_value: Optional[Cache] = None,
  229. output_attentions: bool = False,
  230. use_cache: bool = False,
  231. cache_position: Optional[torch.LongTensor] = None,
  232. position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
  233. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  234. bsz, q_len, _ = hidden_states.size()
  235. query_states = self.q_proj(hidden_states)
  236. key_states = self.k_proj(hidden_states)
  237. value_states = self.v_proj(hidden_states)
  238. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  239. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  240. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  241. if position_embeddings is None:
  242. logger.warning_once(
  243. "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
  244. "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
  245. "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
  246. "removed and `position_embeddings` will be mandatory."
  247. )
  248. cos, sin = self.rotary_emb(value_states, position_ids)
  249. else:
  250. cos, sin = position_embeddings
  251. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  252. if past_key_value is not None:
  253. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
  254. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  255. # repeat k/v heads if n_kv_heads < n_heads
  256. key_states = repeat_kv(key_states, self.num_key_value_groups)
  257. value_states = repeat_kv(value_states, self.num_key_value_groups)
  258. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  259. if attention_mask is not None: # no matter the length, we just slice it
  260. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  261. attn_weights += causal_mask
  262. # upcast attention to fp32
  263. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  264. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  265. attn_output = torch.matmul(attn_weights, value_states)
  266. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  267. raise ValueError(
  268. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  269. f" {attn_output.size()}"
  270. )
  271. attn_output = attn_output.transpose(1, 2).contiguous()
  272. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
  273. attn_output = self.o_proj(attn_output)
  274. attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)
  275. if not output_attentions:
  276. attn_weights = None
  277. return attn_output, attn_weights, past_key_value
  278. class Starcoder2FlashAttention2(Starcoder2Attention):
  279. """
  280. Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays
  281. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  282. flash attention and deal with padding tokens in case the input contains any of them.
  283. """
  284. # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
  285. def __init__(self, *args, **kwargs):
  286. super().__init__(*args, **kwargs)
  287. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  288. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  289. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  290. self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
  291. # Ignore copy
  292. def forward(
  293. self,
  294. hidden_states: torch.Tensor,
  295. attention_mask: Optional[torch.Tensor] = None,
  296. position_ids: Optional[torch.LongTensor] = None,
  297. past_key_value: Optional[Cache] = None,
  298. output_attentions: bool = False,
  299. use_cache: bool = False,
  300. cache_position: Optional[torch.LongTensor] = None,
  301. position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
  302. ):
  303. bsz, q_len, _ = hidden_states.size()
  304. query_states = self.q_proj(hidden_states)
  305. key_states = self.k_proj(hidden_states)
  306. value_states = self.v_proj(hidden_states)
  307. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  308. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  309. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  310. if position_embeddings is None:
  311. logger.warning_once(
  312. "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
  313. "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
  314. "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
  315. "removed and `position_embeddings` will be mandatory."
  316. )
  317. cos, sin = self.rotary_emb(value_states, position_ids)
  318. else:
  319. cos, sin = position_embeddings
  320. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  321. if past_key_value is not None:
  322. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
  323. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  324. # repeat k/v heads if n_kv_heads < n_heads
  325. key_states = repeat_kv(key_states, self.num_key_value_groups)
  326. value_states = repeat_kv(value_states, self.num_key_value_groups)
  327. dropout_rate = 0.0 if not self.training else self.attention_dropout
  328. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  329. # therefore the input hidden states gets silently casted in float32. Hence, we need
  330. # cast them back in float16 just to be sure everything works as expected.
  331. input_dtype = query_states.dtype
  332. if input_dtype == torch.float32:
  333. if torch.is_autocast_enabled():
  334. target_dtype = torch.get_autocast_gpu_dtype()
  335. # Handle the case where the model is quantized
  336. elif hasattr(self.config, "_pre_quantization_dtype"):
  337. target_dtype = self.config._pre_quantization_dtype
  338. else:
  339. target_dtype = self.q_proj.weight.dtype
  340. logger.warning_once(
  341. f"The input hidden states seems to be silently casted in float32, this might be related to"
  342. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  343. f" {target_dtype}."
  344. )
  345. query_states = query_states.to(target_dtype)
  346. key_states = key_states.to(target_dtype)
  347. value_states = value_states.to(target_dtype)
  348. # Reashape to the expected shape for Flash Attention
  349. query_states = query_states.transpose(1, 2)
  350. key_states = key_states.transpose(1, 2)
  351. value_states = value_states.transpose(1, 2)
  352. attn_output = _flash_attention_forward(
  353. query_states,
  354. key_states,
  355. value_states,
  356. attention_mask,
  357. q_len,
  358. position_ids=position_ids,
  359. dropout=dropout_rate,
  360. sliding_window=getattr(self.config, "sliding_window", None),
  361. is_causal=self.is_causal,
  362. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  363. )
  364. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
  365. attn_output = self.o_proj(attn_output)
  366. attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)
  367. if not output_attentions:
  368. attn_weights = None
  369. return attn_output, attn_weights, past_key_value
  370. # Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2
  371. class Starcoder2SdpaAttention(Starcoder2Attention):
  372. """
  373. Starcoder2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  374. `Starcoder2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  375. SDPA API.
  376. """
  377. # Ignore copy
  378. def forward(
  379. self,
  380. hidden_states: torch.Tensor,
  381. attention_mask: Optional[torch.Tensor] = None,
  382. position_ids: Optional[torch.LongTensor] = None,
  383. past_key_value: Optional[Cache] = None,
  384. output_attentions: bool = False,
  385. use_cache: bool = False,
  386. cache_position: Optional[torch.LongTensor] = None,
  387. position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
  388. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  389. if output_attentions:
  390. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
  391. logger.warning_once(
  392. "Starcoder2Model is using Starcoder2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
  393. 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  394. )
  395. return super().forward(
  396. hidden_states=hidden_states,
  397. attention_mask=attention_mask,
  398. position_ids=position_ids,
  399. past_key_value=past_key_value,
  400. output_attentions=output_attentions,
  401. use_cache=use_cache,
  402. )
  403. bsz, q_len, _ = hidden_states.size()
  404. query_states = self.q_proj(hidden_states)
  405. key_states = self.k_proj(hidden_states)
  406. value_states = self.v_proj(hidden_states)
  407. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  408. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  409. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  410. if position_embeddings is None:
  411. logger.warning_once(
  412. "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
  413. "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
  414. "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
  415. "removed and `position_embeddings` will be mandatory."
  416. )
  417. cos, sin = self.rotary_emb(value_states, position_ids)
  418. else:
  419. cos, sin = position_embeddings
  420. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  421. if past_key_value is not None:
  422. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
  423. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  424. key_states = repeat_kv(key_states, self.num_key_value_groups)
  425. value_states = repeat_kv(value_states, self.num_key_value_groups)
  426. causal_mask = attention_mask
  427. if attention_mask is not None: # no matter the length, we just slice it
  428. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  429. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  430. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  431. if query_states.device.type == "cuda" and attention_mask is not None:
  432. query_states = query_states.contiguous()
  433. key_states = key_states.contiguous()
  434. value_states = value_states.contiguous()
  435. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  436. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  437. # # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
  438. is_causal = True if causal_mask is None and q_len > 1 else False
  439. attn_output = torch.nn.functional.scaled_dot_product_attention(
  440. query_states,
  441. key_states,
  442. value_states,
  443. attn_mask=causal_mask,
  444. dropout_p=self.attention_dropout if self.training else 0.0,
  445. is_causal=is_causal,
  446. )
  447. attn_output = attn_output.transpose(1, 2).contiguous()
  448. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
  449. attn_output = self.o_proj(attn_output)
  450. # The difference with Mistral is that here it uses dropout
  451. attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)
  452. return attn_output, None, past_key_value
  453. STARCODER2_ATTENTION_CLASSES = {
  454. "eager": Starcoder2Attention,
  455. "flash_attention_2": Starcoder2FlashAttention2,
  456. "sdpa": Starcoder2SdpaAttention,
  457. }
  458. class Starcoder2DecoderLayer(nn.Module):
  459. def __init__(self, config: Starcoder2Config, layer_idx: int):
  460. super().__init__()
  461. self.hidden_size = config.hidden_size
  462. self.self_attn = STARCODER2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
  463. self.mlp = Starcoder2MLP(config)
  464. self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  465. self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  466. # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer.forward
  467. def forward(
  468. self,
  469. hidden_states: torch.Tensor,
  470. attention_mask: Optional[torch.Tensor] = None,
  471. position_ids: Optional[torch.LongTensor] = None,
  472. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  473. output_attentions: Optional[bool] = False,
  474. use_cache: Optional[bool] = False,
  475. cache_position: Optional[torch.LongTensor] = None,
  476. position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
  477. **kwargs,
  478. ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
  479. """
  480. Args:
  481. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  482. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  483. `(batch, sequence_length)` where padding elements are indicated by 0.
  484. output_attentions (`bool`, *optional*):
  485. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  486. returned tensors for more detail.
  487. use_cache (`bool`, *optional*):
  488. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  489. (see `past_key_values`).
  490. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
  491. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  492. Indices depicting the position of the input sequence tokens in the sequence.
  493. position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  494. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  495. with `head_dim` being the embedding dimension of each attention head.
  496. kwargs (`dict`, *optional*):
  497. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  498. into the model
  499. """
  500. residual = hidden_states
  501. hidden_states = self.input_layernorm(hidden_states)
  502. # Self Attention
  503. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  504. hidden_states=hidden_states,
  505. attention_mask=attention_mask,
  506. position_ids=position_ids,
  507. past_key_value=past_key_value,
  508. output_attentions=output_attentions,
  509. use_cache=use_cache,
  510. cache_position=cache_position,
  511. position_embeddings=position_embeddings,
  512. )
  513. hidden_states = residual + hidden_states
  514. # Fully Connected
  515. residual = hidden_states
  516. hidden_states = self.post_attention_layernorm(hidden_states)
  517. hidden_states = self.mlp(hidden_states)
  518. hidden_states = residual + hidden_states
  519. outputs = (hidden_states,)
  520. if output_attentions:
  521. outputs += (self_attn_weights,)
  522. if use_cache:
  523. outputs += (present_key_value,)
  524. return outputs
# Class-level documentation prefix shared by the Starcoder2 model classes; it is passed to `add_start_docstrings`
# on the model class decorators below.
STARCODER2_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Starcoder2Config`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
  538. @add_start_docstrings(
  539. "The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.",
  540. STARCODER2_START_DOCSTRING,
  541. )
  542. # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2PreTrainedModel with Qwen2->Starcoder2
  543. class Starcoder2PreTrainedModel(PreTrainedModel):
  544. config_class = Starcoder2Config
  545. base_model_prefix = "model"
  546. supports_gradient_checkpointing = True
  547. _no_split_modules = ["Starcoder2DecoderLayer"]
  548. _skip_keys_device_placement = "past_key_values"
  549. _supports_flash_attn_2 = True
  550. _supports_sdpa = True
  551. _supports_cache_class = True
  552. _supports_quantized_cache = True
  553. _supports_static_cache = True
  554. def _init_weights(self, module):
  555. std = self.config.initializer_range
  556. if isinstance(module, nn.Linear):
  557. module.weight.data.normal_(mean=0.0, std=std)
  558. if module.bias is not None:
  559. module.bias.data.zero_()
  560. elif isinstance(module, nn.Embedding):
  561. module.weight.data.normal_(mean=0.0, std=std)
  562. if module.padding_idx is not None:
  563. module.weight.data[module.padding_idx].zero_()
  564. STARCODER2_INPUTS_DOCSTRING = r"""
  565. Args:
  566. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  567. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  568. it.
  569. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  570. [`PreTrainedTokenizer.__call__`] for details.
  571. [What are input IDs?](../glossary#input-ids)
  572. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  573. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  574. - 1 for tokens that are **not masked**,
  575. - 0 for tokens that are **masked**.
  576. [What are attention masks?](../glossary#attention-mask)
  577. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  578. [`PreTrainedTokenizer.__call__`] for details.
  579. If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  580. `past_key_values`).
  581. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
  582. and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
  583. information on the default strategy.
  584. - 1 indicates the head is **not masked**,
  585. - 0 indicates the head is **masked**.
  586. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  587. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  588. config.n_positions - 1]`.
  589. [What are position IDs?](../glossary#position-ids)
  590. past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
  591. Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  592. blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  593. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  594. Two formats are allowed:
  595. - a [`~cache_utils.Cache`] instance, see our
  596. [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
  597. - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  598. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
  599. cache format.
  600. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
  601. legacy cache format will be returned.
  602. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
  603. have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
  604. of shape `(batch_size, sequence_length)`.
  605. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  606. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  607. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  608. model's internal embedding lookup matrix.
  609. use_cache (`bool`, *optional*):
  610. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  611. `past_key_values`).
  612. output_attentions (`bool`, *optional*):
  613. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  614. tensors for more detail.
  615. output_hidden_states (`bool`, *optional*):
  616. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  617. more detail.
  618. return_dict (`bool`, *optional*):
  619. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  620. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  621. Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
  622. this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
  623. the complete sequence length.
  624. """
  625. @add_start_docstrings(
  626. "The bare Starcoder2 Model outputting raw hidden-states without any specific head on top.",
  627. STARCODER2_START_DOCSTRING,
  628. )
  629. class Starcoder2Model(Starcoder2PreTrainedModel):
  630. """
  631. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Starcoder2DecoderLayer`]
  632. Args:
  633. config: Starcoder2Config
  634. """
  635. def __init__(self, config: Starcoder2Config):
  636. super().__init__(config)
  637. self.padding_idx = config.pad_token_id
  638. self.vocab_size = config.vocab_size
  639. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  640. self.embedding_dropout = config.embedding_dropout
  641. self.layers = nn.ModuleList(
  642. [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  643. )
  644. self._attn_implementation = config._attn_implementation
  645. self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
  646. self.rotary_emb = Starcoder2RotaryEmbedding(config=config)
  647. self.gradient_checkpointing = False
  648. # Initialize weights and apply final processing
  649. self.post_init()
  650. def get_input_embeddings(self):
  651. return self.embed_tokens
  652. def set_input_embeddings(self, value):
  653. self.embed_tokens = value
  654. @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
  655. def forward(
  656. self,
  657. input_ids: torch.LongTensor = None,
  658. attention_mask: Optional[torch.Tensor] = None,
  659. position_ids: Optional[torch.LongTensor] = None,
  660. past_key_values: Optional[List[torch.FloatTensor]] = None,
  661. inputs_embeds: Optional[torch.FloatTensor] = None,
  662. use_cache: Optional[bool] = None,
  663. output_attentions: Optional[bool] = None,
  664. output_hidden_states: Optional[bool] = None,
  665. return_dict: Optional[bool] = None,
  666. cache_position: Optional[torch.LongTensor] = None,
  667. ) -> Union[Tuple, BaseModelOutputWithPast]:
  668. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  669. output_hidden_states = (
  670. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  671. )
  672. use_cache = use_cache if use_cache is not None else self.config.use_cache
  673. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  674. if (input_ids is None) ^ (inputs_embeds is not None):
  675. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  676. if self.gradient_checkpointing and self.training:
  677. if use_cache:
  678. logger.warning_once(
  679. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  680. )
  681. use_cache = False
  682. # kept for BC (non `Cache` `past_key_values` inputs)
  683. return_legacy_cache = False
  684. if use_cache and not isinstance(past_key_values, Cache):
  685. return_legacy_cache = True
  686. if past_key_values is None:
  687. past_key_values = DynamicCache()
  688. else:
  689. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  690. logger.warning_once(
  691. "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
  692. "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
  693. "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
  694. )
  695. if inputs_embeds is None:
  696. inputs_embeds = self.embed_tokens(input_ids)
  697. if cache_position is None:
  698. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  699. cache_position = torch.arange(
  700. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  701. )
  702. if position_ids is None:
  703. position_ids = cache_position.unsqueeze(0)
  704. causal_mask = self._update_causal_mask(
  705. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  706. )
  707. hidden_states = inputs_embeds
  708. hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training)
  709. # create position embeddings to be shared across the decoder layers
  710. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  711. # decoder layers
  712. all_hidden_states = () if output_hidden_states else None
  713. all_self_attns = () if output_attentions else None
  714. next_decoder_cache = None
  715. for decoder_layer in self.layers:
  716. if output_hidden_states:
  717. all_hidden_states += (hidden_states,)
  718. if self.gradient_checkpointing and self.training:
  719. layer_outputs = self._gradient_checkpointing_func(
  720. decoder_layer.__call__,
  721. hidden_states,
  722. causal_mask,
  723. position_ids,
  724. past_key_values,
  725. output_attentions,
  726. use_cache,
  727. cache_position,
  728. position_embeddings,
  729. )
  730. else:
  731. layer_outputs = decoder_layer(
  732. hidden_states,
  733. attention_mask=causal_mask,
  734. position_ids=position_ids,
  735. past_key_value=past_key_values,
  736. output_attentions=output_attentions,
  737. use_cache=use_cache,
  738. cache_position=cache_position,
  739. position_embeddings=position_embeddings,
  740. )
  741. hidden_states = layer_outputs[0]
  742. if use_cache:
  743. next_decoder_cache = layer_outputs[2 if output_attentions else 1]
  744. if output_attentions:
  745. all_self_attns += (layer_outputs[1],)
  746. hidden_states = self.norm(hidden_states)
  747. # add hidden states from the last decoder layer
  748. if output_hidden_states:
  749. all_hidden_states += (hidden_states,)
  750. next_cache = next_decoder_cache if use_cache else None
  751. if return_legacy_cache:
  752. next_cache = next_cache.to_legacy_cache()
  753. if not return_dict:
  754. return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
  755. return BaseModelOutputWithPast(
  756. last_hidden_state=hidden_states,
  757. past_key_values=next_cache,
  758. hidden_states=all_hidden_states,
  759. attentions=all_self_attns,
  760. )
  761. # Copied from transformers.models.phi3.modeling_phi3.Phi3Model._update_causal_mask
  762. def _update_causal_mask(
  763. self,
  764. attention_mask: torch.Tensor,
  765. input_tensor: torch.Tensor,
  766. cache_position: torch.Tensor,
  767. past_key_values: Cache,
  768. output_attentions: bool,
  769. ):
  770. if self.config._attn_implementation == "flash_attention_2":
  771. if attention_mask is not None and 0.0 in attention_mask:
  772. return attention_mask
  773. return None
  774. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  775. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  776. # to infer the attention mask.
  777. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  778. using_static_cache = isinstance(past_key_values, StaticCache)
  779. using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
  780. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  781. if (
  782. self.config._attn_implementation == "sdpa"
  783. and not (using_static_cache or using_sliding_window_cache)
  784. and not output_attentions
  785. ):
  786. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  787. attention_mask,
  788. inputs_embeds=input_tensor,
  789. past_key_values_length=past_seen_tokens,
  790. sliding_window=self.config.sliding_window,
  791. is_training=self.training,
  792. ):
  793. return None
  794. dtype, device = input_tensor.dtype, input_tensor.device
  795. min_dtype = torch.finfo(dtype).min
  796. sequence_length = input_tensor.shape[1]
  797. # SlidingWindowCache or StaticCache
  798. if using_sliding_window_cache or using_static_cache:
  799. target_length = past_key_values.get_max_cache_shape()
  800. # DynamicCache or no cache
  801. else:
  802. target_length = (
  803. attention_mask.shape[-1]
  804. if isinstance(attention_mask, torch.Tensor)
  805. else past_seen_tokens + sequence_length + 1
  806. )
  807. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  808. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  809. attention_mask,
  810. sequence_length=sequence_length,
  811. target_length=target_length,
  812. dtype=dtype,
  813. device=device,
  814. cache_position=cache_position,
  815. batch_size=input_tensor.shape[0],
  816. config=self.config,
  817. past_key_values=past_key_values,
  818. )
  819. if (
  820. self.config._attn_implementation == "sdpa"
  821. and attention_mask is not None
  822. and attention_mask.device.type == "cuda"
  823. and not output_attentions
  824. ):
  825. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  826. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  827. # Details: https://github.com/pytorch/pytorch/issues/110213
  828. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  829. return causal_mask
  830. @staticmethod
  831. # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Starcoder2
  832. def _prepare_4d_causal_attention_mask_with_cache_position(
  833. attention_mask: torch.Tensor,
  834. sequence_length: int,
  835. target_length: int,
  836. dtype: torch.dtype,
  837. device: torch.device,
  838. cache_position: torch.Tensor,
  839. batch_size: int,
  840. config: Starcoder2Config,
  841. past_key_values: Cache,
  842. ):
  843. """
  844. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  845. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  846. Args:
  847. attention_mask (`torch.Tensor`):
  848. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
  849. sequence_length (`int`):
  850. The sequence length being processed.
  851. target_length (`int`):
  852. The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
  853. dtype (`torch.dtype`):
  854. The dtype to use for the 4D attention mask.
  855. device (`torch.device`):
  856. The device to plcae the 4D attention mask on.
  857. cache_position (`torch.Tensor`):
  858. Indices depicting the position of the input sequence tokens in the sequence.
  859. batch_size (`torch.Tensor`):
  860. Batch size.
  861. config (`Starcoder2Config`):
  862. The model's configuration class
  863. past_key_values (`Cache`):
  864. The cache class that is being used currently to generate
  865. """
  866. if attention_mask is not None and attention_mask.dim() == 4:
  867. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  868. causal_mask = attention_mask
  869. else:
  870. min_dtype = torch.finfo(dtype).min
  871. causal_mask = torch.full(
  872. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
  873. )
  874. diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
  875. if config.sliding_window is not None:
  876. # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
  877. # the check is needed to verify is current checkpoint was trained with sliding window or not
  878. if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
  879. sliding_attend_mask = torch.arange(target_length, device=device) <= (
  880. cache_position.reshape(-1, 1) - config.sliding_window
  881. )
  882. diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
  883. causal_mask *= diagonal_attend_mask
  884. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  885. if attention_mask is not None:
  886. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  887. if attention_mask.shape[-1] > target_length:
  888. attention_mask = attention_mask[:, :target_length]
  889. mask_length = attention_mask.shape[-1]
  890. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
  891. padding_mask = padding_mask == 0
  892. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  893. padding_mask, min_dtype
  894. )
  895. return causal_mask
  896. # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM with QWEN2->STARCODER2,Qwen2->Starcoder2
  897. class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin):
  898. _tied_weights_keys = ["lm_head.weight"]
  899. def __init__(self, config):
  900. super().__init__(config)
  901. self.model = Starcoder2Model(config)
  902. self.vocab_size = config.vocab_size
  903. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  904. # Initialize weights and apply final processing
  905. self.post_init()
  906. def get_input_embeddings(self):
  907. return self.model.embed_tokens
  908. def set_input_embeddings(self, value):
  909. self.model.embed_tokens = value
  910. def get_output_embeddings(self):
  911. return self.lm_head
  912. def set_output_embeddings(self, new_embeddings):
  913. self.lm_head = new_embeddings
  914. def set_decoder(self, decoder):
  915. self.model = decoder
  916. def get_decoder(self):
  917. return self.model
  918. @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
  919. @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  920. # Ignore copy
  921. def forward(
  922. self,
  923. input_ids: torch.LongTensor = None,
  924. attention_mask: Optional[torch.Tensor] = None,
  925. position_ids: Optional[torch.LongTensor] = None,
  926. past_key_values: Optional[List[torch.FloatTensor]] = None,
  927. inputs_embeds: Optional[torch.FloatTensor] = None,
  928. labels: Optional[torch.LongTensor] = None,
  929. use_cache: Optional[bool] = None,
  930. output_attentions: Optional[bool] = None,
  931. output_hidden_states: Optional[bool] = None,
  932. return_dict: Optional[bool] = None,
  933. cache_position: Optional[torch.LongTensor] = None,
  934. num_logits_to_keep: int = 0,
  935. ) -> Union[Tuple, CausalLMOutputWithPast]:
  936. r"""
  937. Args:
  938. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  939. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  940. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  941. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  942. num_logits_to_keep (`int`, *optional*):
  943. Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
  944. `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  945. token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
  946. Returns:
  947. Example:
  948. ```python
  949. >>> from transformers import AutoTokenizer, Starcoder2ForCausalLM
  950. >>> model = Starcoder2ForCausalLM.from_pretrained("bigcode/starcoder2-7b")
  951. >>> tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-7b")
  952. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  953. >>> inputs = tokenizer(prompt, return_tensors="pt")
  954. >>> # Generate
  955. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  956. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  957. "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
  958. ```"""
  959. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  960. output_hidden_states = (
  961. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  962. )
  963. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  964. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  965. outputs = self.model(
  966. input_ids=input_ids,
  967. attention_mask=attention_mask,
  968. position_ids=position_ids,
  969. past_key_values=past_key_values,
  970. inputs_embeds=inputs_embeds,
  971. use_cache=use_cache,
  972. output_attentions=output_attentions,
  973. output_hidden_states=output_hidden_states,
  974. return_dict=return_dict,
  975. cache_position=cache_position,
  976. )
  977. hidden_states = outputs[0]
  978. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  979. logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
  980. loss = None
  981. if labels is not None:
  982. # Upcast to float if we need to compute the loss to avoid potential precision issues
  983. logits = logits.float()
  984. # Shift so that tokens < n predict n
  985. shift_logits = logits[..., :-1, :].contiguous()
  986. shift_labels = labels[..., 1:].contiguous()
  987. # Flatten the tokens
  988. shift_logits = shift_logits.view(-1, self.config.vocab_size)
  989. shift_labels = shift_labels.view(-1)
  990. # Ensure tensors are on the same device
  991. shift_labels = shift_labels.to(shift_logits.device)
  992. loss_fct = CrossEntropyLoss()
  993. loss = loss_fct(shift_logits, shift_labels)
  994. if not return_dict:
  995. output = (logits,) + outputs[1:]
  996. return (loss,) + output if loss is not None else output
  997. return CausalLMOutputWithPast(
  998. loss=loss,
  999. logits=logits,
  1000. past_key_values=outputs.past_key_values,
  1001. hidden_states=outputs.hidden_states,
  1002. attentions=outputs.attentions,
  1003. )
  1004. @add_start_docstrings(
  1005. """
  1006. The Starcoder2 Model transformer with a sequence classification head on top (linear layer).
  1007. [`Starcoder2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  1008. (e.g. GPT-2) do.
  1009. Since it does classification on the last token, it requires to know the position of the last token. If a
  1010. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  1011. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  1012. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  1013. each row of the batch).
  1014. """,
  1015. STARCODER2_START_DOCSTRING,
  1016. )
  1017. # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Starcoder2, LLAMA->STARCODER2
  1018. class Starcoder2ForSequenceClassification(Starcoder2PreTrainedModel):
  1019. def __init__(self, config):
  1020. super().__init__(config)
  1021. self.num_labels = config.num_labels
  1022. self.model = Starcoder2Model(config)
  1023. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  1024. # Initialize weights and apply final processing
  1025. self.post_init()
  1026. def get_input_embeddings(self):
  1027. return self.model.embed_tokens
  1028. def set_input_embeddings(self, value):
  1029. self.model.embed_tokens = value
  1030. @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
  1031. def forward(
  1032. self,
  1033. input_ids: Optional[torch.LongTensor] = None,
  1034. attention_mask: Optional[torch.Tensor] = None,
  1035. position_ids: Optional[torch.LongTensor] = None,
  1036. past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  1037. inputs_embeds: Optional[torch.FloatTensor] = None,
  1038. labels: Optional[torch.LongTensor] = None,
  1039. use_cache: Optional[bool] = None,
  1040. output_attentions: Optional[bool] = None,
  1041. output_hidden_states: Optional[bool] = None,
  1042. return_dict: Optional[bool] = None,
  1043. ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
  1044. r"""
  1045. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1046. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1047. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1048. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1049. """
  1050. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1051. transformer_outputs = self.model(
  1052. input_ids,
  1053. attention_mask=attention_mask,
  1054. position_ids=position_ids,
  1055. past_key_values=past_key_values,
  1056. inputs_embeds=inputs_embeds,
  1057. use_cache=use_cache,
  1058. output_attentions=output_attentions,
  1059. output_hidden_states=output_hidden_states,
  1060. return_dict=return_dict,
  1061. )
  1062. hidden_states = transformer_outputs[0]
  1063. logits = self.score(hidden_states)
  1064. if input_ids is not None:
  1065. batch_size = input_ids.shape[0]
  1066. else:
  1067. batch_size = inputs_embeds.shape[0]
  1068. if self.config.pad_token_id is None and batch_size != 1:
  1069. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  1070. if self.config.pad_token_id is None:
  1071. sequence_lengths = -1
  1072. else:
  1073. if input_ids is not None:
  1074. # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
  1075. sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
  1076. sequence_lengths = sequence_lengths % input_ids.shape[-1]
  1077. sequence_lengths = sequence_lengths.to(logits.device)
  1078. else:
  1079. sequence_lengths = -1
  1080. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
  1081. loss = None
  1082. if labels is not None:
  1083. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  1084. if not return_dict:
  1085. output = (pooled_logits,) + transformer_outputs[1:]
  1086. return ((loss,) + output) if loss is not None else output
  1087. return SequenceClassifierOutputWithPast(
  1088. loss=loss,
  1089. logits=pooled_logits,
  1090. past_key_values=transformer_outputs.past_key_values,
  1091. hidden_states=transformer_outputs.hidden_states,
  1092. attentions=transformer_outputs.attentions,
  1093. )
  1094. @add_start_docstrings(
  1095. """
  1096. The Starcoder2 Model transformer with a token classification head on top (a linear layer on top of the hidden-states
  1097. output) e.g. for Named-Entity-Recognition (NER) tasks.
  1098. """,
  1099. STARCODER2_START_DOCSTRING,
  1100. )
  1101. # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Starcoder2, LLAMA->STARCODER2
  1102. class Starcoder2ForTokenClassification(Starcoder2PreTrainedModel):
  1103. def __init__(self, config):
  1104. super().__init__(config)
  1105. self.num_labels = config.num_labels
  1106. self.model = Starcoder2Model(config)
  1107. if getattr(config, "classifier_dropout", None) is not None:
  1108. classifier_dropout = config.classifier_dropout
  1109. elif getattr(config, "hidden_dropout", None) is not None:
  1110. classifier_dropout = config.hidden_dropout
  1111. else:
  1112. classifier_dropout = 0.1
  1113. self.dropout = nn.Dropout(classifier_dropout)
  1114. self.score = nn.Linear(config.hidden_size, config.num_labels)
  1115. # Initialize weights and apply final processing
  1116. self.post_init()
  1117. def get_input_embeddings(self):
  1118. return self.model.embed_tokens
  1119. def set_input_embeddings(self, value):
  1120. self.model.embed_tokens = value
  1121. @add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
  1122. @add_code_sample_docstrings(
  1123. checkpoint=_CHECKPOINT_FOR_DOC,
  1124. output_type=TokenClassifierOutput,
  1125. config_class=_CONFIG_FOR_DOC,
  1126. )
  1127. def forward(
  1128. self,
  1129. input_ids: Optional[torch.LongTensor] = None,
  1130. attention_mask: Optional[torch.Tensor] = None,
  1131. position_ids: Optional[torch.LongTensor] = None,
  1132. past_key_values: Optional[List[torch.FloatTensor]] = None,
  1133. inputs_embeds: Optional[torch.FloatTensor] = None,
  1134. labels: Optional[torch.LongTensor] = None,
  1135. use_cache: Optional[bool] = None,
  1136. output_attentions: Optional[bool] = None,
  1137. output_hidden_states: Optional[bool] = None,
  1138. return_dict: Optional[bool] = None,
  1139. ) -> Union[Tuple, TokenClassifierOutput]:
  1140. r"""
  1141. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1142. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1143. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1144. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1145. """
  1146. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1147. outputs = self.model(
  1148. input_ids,
  1149. attention_mask=attention_mask,
  1150. position_ids=position_ids,
  1151. past_key_values=past_key_values,
  1152. inputs_embeds=inputs_embeds,
  1153. use_cache=use_cache,
  1154. output_attentions=output_attentions,
  1155. output_hidden_states=output_hidden_states,
  1156. return_dict=return_dict,
  1157. )
  1158. sequence_output = outputs[0]
  1159. sequence_output = self.dropout(sequence_output)
  1160. logits = self.score(sequence_output)
  1161. loss = None
  1162. if labels is not None:
  1163. loss = self.loss_function(logits, labels, self.config)
  1164. if not return_dict:
  1165. output = (logits,) + outputs[2:]
  1166. return ((loss,) + output) if loss is not None else output
  1167. return TokenClassifierOutput(
  1168. loss=loss,
  1169. logits=logits,
  1170. hidden_states=outputs.hidden_states,
  1171. attentions=outputs.attentions,
  1172. )