modeling_gemma.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/gemma/modular_gemma.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_gemma.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. import math
  23. from typing import List, Optional, Tuple, Union
  24. import torch
  25. import torch.utils.checkpoint
  26. from torch import nn
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache, StaticCache
  29. from ...generation import GenerationMixin
  30. from ...modeling_attn_mask_utils import AttentionMaskConverter
  31. from ...modeling_flash_attention_utils import _flash_attention_forward
  32. from ...modeling_outputs import (
  33. BaseModelOutputWithPast,
  34. CausalLMOutputWithPast,
  35. SequenceClassifierOutputWithPast,
  36. TokenClassifierOutput,
  37. )
  38. from ...modeling_utils import PreTrainedModel
  39. from ...utils import (
  40. add_code_sample_docstrings,
  41. add_start_docstrings,
  42. add_start_docstrings_to_model_forward,
  43. is_flash_attn_greater_or_equal_2_10,
  44. logging,
  45. replace_return_docstrings,
  46. )
  47. from .configuration_gemma import GemmaConfig
# Hub checkpoint name used for documentation purposes; presumably consumed by the docstring helpers
# (e.g. `add_code_sample_docstrings`) further down the file — confirm against the rest of the module.
_CHECKPOINT_FOR_DOC = "google/gemma-7b"
  49. class GemmaRMSNorm(nn.Module):
  50. def __init__(self, dim: int, eps: float = 1e-6):
  51. super().__init__()
  52. self.eps = eps
  53. self.weight = nn.Parameter(torch.zeros(dim))
  54. def _norm(self, x):
  55. return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
  56. def forward(self, x):
  57. output = self._norm(x.float())
  58. # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
  59. # See https://github.com/huggingface/transformers/pull/29402
  60. output = output * (1.0 + self.weight.float())
  61. return output.type_as(x)
  62. def extra_repr(self):
  63. return f"{tuple(self.weight.shape)}, eps={self.eps}"
# Module-level logger; used by the `warning_once` calls in GemmaMLP and the attention classes.
logger = logging.get_logger(__name__)
  65. class GemmaRotaryEmbedding(nn.Module):
  66. def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
  67. super().__init__()
  68. self.dim = dim
  69. self.max_position_embeddings = max_position_embeddings
  70. self.base = base
  71. inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
  72. self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
  73. @torch.no_grad()
  74. def forward(self, x, position_ids, seq_len=None):
  75. # x: [bs, num_attention_heads, seq_len, head_size]
  76. self.inv_freq.to(x.device)
  77. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  78. position_ids_expanded = position_ids[:, None, :].float()
  79. # Force float32 since bfloat16 loses precision on long contexts
  80. # See https://github.com/huggingface/transformers/pull/29285
  81. device_type = x.device.type
  82. device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
  83. with torch.autocast(device_type=device_type, enabled=False):
  84. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  85. emb = torch.cat((freqs, freqs), dim=-1)
  86. cos = emb.cos()
  87. sin = emb.sin()
  88. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  89. class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
  90. """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
  91. def forward(self, x, position_ids):
  92. # difference to the original RoPE: a scaling factor is aplied to the position ids
  93. position_ids = position_ids.float() / self.scaling_factor
  94. cos, sin = super().forward(x, position_ids)
  95. return cos, sin
  96. class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
  97. """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
  98. def forward(self, x, position_ids):
  99. # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
  100. seq_len = torch.max(position_ids) + 1
  101. if seq_len > self.max_position_embeddings:
  102. base = self.base * (
  103. (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
  104. ) ** (self.dim / (self.dim - 2))
  105. inv_freq = 1.0 / (
  106. base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
  107. )
  108. self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
  109. cos, sin = super().forward(x, position_ids)
  110. return cos, sin
  111. class GemmaMLP(nn.Module):
  112. def __init__(self, config):
  113. super().__init__()
  114. self.config = config
  115. self.hidden_size = config.hidden_size
  116. self.intermediate_size = config.intermediate_size
  117. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  118. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  119. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  120. if config.hidden_activation is None:
  121. logger.warning_once(
  122. "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
  123. "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
  124. "`config.hidden_activation` if you want to override this behaviour.\n"
  125. "See https://github.com/huggingface/transformers/pull/29402 for more details."
  126. )
  127. config.hidden_activation = "gelu_pytorch_tanh"
  128. hidden_activation = config.hidden_activation
  129. self.act_fn = ACT2FN[hidden_activation]
  130. def forward(self, x):
  131. return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  132. def rotate_half(x):
  133. """Rotates half the hidden dims of the input."""
  134. x1 = x[..., : x.shape[-1] // 2]
  135. x2 = x[..., x.shape[-1] // 2 :]
  136. return torch.cat((-x2, x1), dim=-1)
  137. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  138. """Applies Rotary Position Embedding to the query and key tensors.
  139. Args:
  140. q (`torch.Tensor`): The query tensor.
  141. k (`torch.Tensor`): The key tensor.
  142. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  143. sin (`torch.Tensor`): The sine part of the rotary embedding.
  144. position_ids (`torch.Tensor`, *optional*):
  145. Deprecated and unused.
  146. unsqueeze_dim (`int`, *optional*, defaults to 1):
  147. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  148. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  149. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  150. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  151. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  152. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  153. Returns:
  154. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  155. """
  156. cos = cos.unsqueeze(unsqueeze_dim)
  157. sin = sin.unsqueeze(unsqueeze_dim)
  158. q_embed = (q * cos) + (rotate_half(q) * sin)
  159. k_embed = (k * cos) + (rotate_half(k) * sin)
  160. return q_embed, k_embed
  161. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  162. """
  163. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  164. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  165. """
  166. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  167. if n_rep == 1:
  168. return hidden_states
  169. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  170. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  171. class GemmaAttention(nn.Module):
  172. """Multi-headed attention from 'Attention Is All You Need' paper"""
  173. def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
  174. super().__init__()
  175. self.config = config
  176. self.layer_idx = layer_idx
  177. if layer_idx is None:
  178. logger.warning_once(
  179. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  180. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  181. "when creating this class."
  182. )
  183. self.attention_dropout = config.attention_dropout
  184. self.hidden_size = config.hidden_size
  185. self.num_heads = config.num_attention_heads
  186. self.head_dim = config.head_dim
  187. self.num_key_value_heads = config.num_key_value_heads
  188. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  189. self.max_position_embeddings = config.max_position_embeddings
  190. self.rope_theta = config.rope_theta
  191. self.is_causal = True
  192. self.scaling = 1 / math.sqrt(config.head_dim)
  193. if self.hidden_size % self.num_heads != 0:
  194. raise ValueError(
  195. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  196. f" and `num_heads`: {self.num_heads})."
  197. )
  198. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
  199. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  200. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  201. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
  202. self.rotary_emb = GemmaRotaryEmbedding(
  203. self.head_dim,
  204. max_position_embeddings=self.max_position_embeddings,
  205. base=self.rope_theta,
  206. )
  207. def forward(
  208. self,
  209. hidden_states: torch.Tensor,
  210. attention_mask: Optional[torch.Tensor] = None,
  211. position_ids: Optional[torch.LongTensor] = None,
  212. past_key_value: Optional[Cache] = None,
  213. output_attentions: bool = False,
  214. use_cache: bool = False,
  215. cache_position: Optional[torch.LongTensor] = None,
  216. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  217. bsz, q_len, _ = hidden_states.size()
  218. query_states = self.q_proj(hidden_states)
  219. key_states = self.k_proj(hidden_states)
  220. value_states = self.v_proj(hidden_states)
  221. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  222. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  223. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  224. cos, sin = self.rotary_emb(value_states, position_ids)
  225. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  226. if past_key_value is not None:
  227. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  228. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  229. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  230. key_states = repeat_kv(key_states, self.num_key_value_groups)
  231. value_states = repeat_kv(value_states, self.num_key_value_groups)
  232. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
  233. if attention_mask is not None: # no matter the length, we just slice it
  234. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  235. attn_weights = attn_weights + causal_mask
  236. # upcast attention to fp32
  237. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  238. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  239. attn_output = torch.matmul(attn_weights, value_states)
  240. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  241. raise ValueError(
  242. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  243. f" {attn_output.size()}"
  244. )
  245. attn_output = attn_output.transpose(1, 2).contiguous()
  246. attn_output = attn_output.view(bsz, q_len, -1)
  247. attn_output = self.o_proj(attn_output)
  248. if not output_attentions:
  249. attn_weights = None
  250. return attn_output, attn_weights, past_key_value
class GemmaSdpaAttention(GemmaAttention):
    """
    Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from GemmaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """SDPA attention forward pass.

        When ``output_attentions`` is True, this logs a one-time warning and delegates to the eager
        ``GemmaAttention.forward``. Extra ``**kwargs`` are not passed along on that path.

        Otherwise it returns ``(attn_output, None, past_key_value)``; SDPA never returns attention weights.
        ``use_cache`` and ``**kwargs`` are accepted but not read on the SDPA path.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # -> (bsz, heads, q_len, head_dim), the layout SDPA expects
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Grouped-query attention: expand KV heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Trim the mask to the current key length (the cache may hold fewer positions than the mask covers).
        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # With a single query token (q_len == 1) no causal masking is needed, so is_causal is False.
        is_causal = True if causal_mask is None and q_len > 1 else False

        # No explicit `scale` is passed: SDPA defaults to 1/sqrt(last dim of query), i.e. 1/sqrt(head_dim),
        # which matches `self.scaling` used by the eager path.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, heads * head_dim) before the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
  323. class GemmaFlashAttention2(GemmaAttention):
  324. """
  325. Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays
  326. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  327. flash attention and deal with padding tokens in case the input contains any of them.
  328. """
  329. def __init__(self, *args, **kwargs):
  330. super().__init__(*args, **kwargs)
  331. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  332. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  333. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  334. self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
  335. def forward(
  336. self,
  337. hidden_states: torch.Tensor,
  338. attention_mask: Optional[torch.LongTensor] = None,
  339. position_ids: Optional[torch.LongTensor] = None,
  340. past_key_value: Optional[Cache] = None,
  341. output_attentions: bool = False,
  342. use_cache: bool = False,
  343. cache_position: Optional[torch.LongTensor] = None,
  344. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  345. if isinstance(past_key_value, StaticCache):
  346. raise ValueError(
  347. "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
  348. "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
  349. )
  350. output_attentions = False
  351. bsz, q_len, _ = hidden_states.size()
  352. query_states = self.q_proj(hidden_states)
  353. key_states = self.k_proj(hidden_states)
  354. value_states = self.v_proj(hidden_states)
  355. # Flash attention requires the input to have the shape
  356. # batch_size x seq_length x head_dim x hidden_dim
  357. # therefore we just need to keep the original shape
  358. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  359. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  360. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  361. cos, sin = self.rotary_emb(value_states, position_ids)
  362. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  363. if past_key_value is not None:
  364. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  365. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  366. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  367. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  368. # to be able to avoid many of these transpose/reshape/view.
  369. query_states = query_states.transpose(1, 2)
  370. key_states = key_states.transpose(1, 2)
  371. value_states = value_states.transpose(1, 2)
  372. dropout_rate = self.attention_dropout if self.training else 0.0
  373. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  374. # therefore the input hidden states gets silently casted in float32. Hence, we need
  375. # cast them back in the correct dtype just to be sure everything works as expected.
  376. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  377. # in fp32. (GemmaRMSNorm handles it correctly)
  378. input_dtype = query_states.dtype
  379. if input_dtype == torch.float32:
  380. if torch.is_autocast_enabled():
  381. target_dtype = torch.get_autocast_gpu_dtype()
  382. # Handle the case where the model is quantized
  383. elif hasattr(self.config, "_pre_quantization_dtype"):
  384. target_dtype = self.config._pre_quantization_dtype
  385. else:
  386. target_dtype = self.q_proj.weight.dtype
  387. logger.warning_once(
  388. f"The input hidden states seems to be silently casted in float32, this might be related to"
  389. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  390. f" {target_dtype}."
  391. )
  392. query_states = query_states.to(target_dtype)
  393. key_states = key_states.to(target_dtype)
  394. value_states = value_states.to(target_dtype)
  395. attn_output = _flash_attention_forward(
  396. query_states,
  397. key_states,
  398. value_states,
  399. attention_mask,
  400. q_len,
  401. position_ids=position_ids,
  402. dropout=dropout_rate,
  403. sliding_window=getattr(self, "sliding_window", None),
  404. is_causal=self.is_causal,
  405. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  406. )
  407. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  408. attn_output = self.o_proj(attn_output)
  409. if not output_attentions:
  410. attn_weights = None
  411. return attn_output, attn_weights, past_key_value
  412. GEMMA_ATTENTION_CLASSES = {
  413. "eager": GemmaAttention,
  414. "flash_attention_2": GemmaFlashAttention2,
  415. "sdpa": GemmaSdpaAttention,
  416. }
  417. class GemmaDecoderLayer(nn.Module):
  418. def __init__(self, config: GemmaConfig, layer_idx: int):
  419. super().__init__()
  420. self.hidden_size = config.hidden_size
  421. self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
  422. self.mlp = GemmaMLP(config)
  423. self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  424. self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  425. def forward(
  426. self,
  427. hidden_states: torch.Tensor,
  428. attention_mask: Optional[torch.Tensor] = None,
  429. position_ids: Optional[torch.LongTensor] = None,
  430. past_key_value: Optional[Cache] = None,
  431. output_attentions: Optional[bool] = False,
  432. use_cache: Optional[bool] = False,
  433. cache_position: Optional[torch.LongTensor] = None,
  434. **kwargs,
  435. ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
  436. """
  437. Args:
  438. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  439. attention_mask (`torch.FloatTensor`, *optional*):
  440. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  441. query_sequence_length, key_sequence_length)` if default attention is used.
  442. output_attentions (`bool`, *optional*):
  443. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  444. returned tensors for more detail.
  445. use_cache (`bool`, *optional*):
  446. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  447. (see `past_key_values`).
  448. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
  449. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  450. Indices depicting the position of the input sequence tokens in the sequence
  451. kwargs (`dict`, *optional*):
  452. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  453. into the model
  454. """
  455. residual = hidden_states
  456. hidden_states = self.input_layernorm(hidden_states)
  457. # Self Attention
  458. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  459. hidden_states=hidden_states,
  460. attention_mask=attention_mask,
  461. position_ids=position_ids,
  462. past_key_value=past_key_value,
  463. output_attentions=output_attentions,
  464. use_cache=use_cache,
  465. cache_position=cache_position,
  466. **kwargs,
  467. )
  468. hidden_states = residual + hidden_states
  469. # Fully Connected
  470. residual = hidden_states
  471. hidden_states = self.post_attention_layernorm(hidden_states)
  472. hidden_states = self.mlp(hidden_states)
  473. hidden_states = residual + hidden_states
  474. outputs = (hidden_states,)
  475. if output_attentions:
  476. outputs += (self_attn_weights,)
  477. if use_cache:
  478. outputs += (present_key_value,)
  479. return outputs
# Shared docstring fragment, prepended to GemmaPreTrainedModel's docs via its `add_start_docstrings` decorator.
GEMMA_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`GemmaConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
  @add_start_docstrings(
      "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
      GEMMA_START_DOCSTRING,
  )
  class GemmaPreTrainedModel(PreTrainedModel):
      # Metadata and capability flags declared for the `PreTrainedModel` base class (their consumers live outside
      # this file). Subclasses store their backbone as `self.model`, matching `base_model_prefix`.
      config_class = GemmaConfig
      base_model_prefix = "model"
      supports_gradient_checkpointing = True
      _no_split_modules = ["GemmaDecoderLayer"]
      _skip_keys_device_placement = ["past_key_values"]
      _supports_flash_attn_2 = True
      _supports_sdpa = True
      _supports_cache_class = True
      _supports_quantized_cache = True
      _supports_static_cache = True

      def _init_weights(self, module):
          """Initialise the parameters of a single submodule in place.

          `nn.Linear` and `nn.Embedding` weights are drawn from a normal distribution with mean 0 and standard
          deviation `config.initializer_range`. Linear biases (when present) and the embedding row at `padding_idx`
          (when set) are then zeroed. Modules of any other type are left untouched.
          """
          init_std = self.config.initializer_range
          is_linear = isinstance(module, nn.Linear)
          if not (is_linear or isinstance(module, nn.Embedding)):
              return

          module.weight.data.normal_(mean=0.0, std=init_std)
          if is_linear:
              if module.bias is not None:
                  module.bias.data.zero_()
          elif module.padding_idx is not None:
              # Keep the padding row at zero so padded positions contribute a zero embedding at init.
              module.weight.data[module.padding_idx].zero_()
  518. _CONFIG_FOR_DOC = "GemmaConfig"
  519. GEMMA_INPUTS_DOCSTRING = r"""
  520. Args:
  521. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  522. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  523. it.
  524. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  525. [`PreTrainedTokenizer.__call__`] for details.
  526. [What are input IDs?](../glossary#input-ids)
  527. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  528. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  529. - 1 for tokens that are **not masked**,
  530. - 0 for tokens that are **masked**.
  531. [What are attention masks?](../glossary#attention-mask)
  532. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  533. [`PreTrainedTokenizer.__call__`] for details.
  534. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
  535. `past_key_values`).
  536. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
  537. and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
  538. information on the default strategy.
  539. - 1 indicates the head is **not masked**,
  540. - 0 indicates the head is **masked**.
  541. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  542. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  543. config.n_positions - 1]`.
  544. [What are position IDs?](../glossary#position-ids)
  545. past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
  546. Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  547. blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  548. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  549. Two formats are allowed:
  550. - a [`~cache_utils.Cache`] instance, see our
  551. [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
  552. - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  553. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
  554. cache format.
  555. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
  556. legacy cache format will be returned.
  557. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
  558. have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
  559. of shape `(batch_size, sequence_length)`.
  560. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  561. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  562. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  563. model's internal embedding lookup matrix.
  564. use_cache (`bool`, *optional*):
  565. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  566. `past_key_values`).
  567. output_attentions (`bool`, *optional*):
  568. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  569. tensors for more detail.
  570. output_hidden_states (`bool`, *optional*):
  571. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  572. more detail.
  573. return_dict (`bool`, *optional*):
  574. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  575. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  576. Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
  577. this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
  578. the complete sequence length.
  579. """
  @add_start_docstrings(
      "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
      GEMMA_START_DOCSTRING,
  )
  class GemmaModel(GemmaPreTrainedModel):
      """
      Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`].

      Args:
          config: GemmaConfig
      """

      def __init__(self, config: GemmaConfig):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size

          # Token embedding table; `config.pad_token_id` is handed to nn.Embedding as its padding index.
          self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
          # One decoder layer per `config.num_hidden_layers`; each layer is told its own index.
          self.layers = nn.ModuleList(
              [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
          )
          # Final RMS norm, applied once after the last decoder layer in `forward`.
          self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.gradient_checkpointing = False

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.embed_tokens

      def set_input_embeddings(self, value):
          self.embed_tokens = value

      # Embeds (or accepts) the inputs, scales them by sqrt(hidden_size), runs every decoder layer, and applies the
      # final norm. Returns a BaseModelOutputWithPast, or a tuple of its non-None fields when `return_dict` is False.
      @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
      def forward(
          self,
          input_ids: torch.LongTensor = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
      ) -> Union[Tuple, BaseModelOutputWithPast]:
          # Unset flags fall back to the model config.
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          use_cache = use_cache if use_cache is not None else self.config.use_cache
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          # XOR: raises when both, or neither, of `input_ids` / `inputs_embeds` are supplied.
          if (input_ids is None) ^ (inputs_embeds is not None):
              raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

          if self.gradient_checkpointing and self.training and use_cache:
              logger.warning_once(
                  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
              )
              use_cache = False

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

          # kept for BC (non `Cache` `past_key_values` inputs)
          # When caching without a `Cache` object (None or a legacy tuple of tuples), wrap it in a DynamicCache and
          # remember to convert the result back to the legacy format before returning.
          # NOTE(review): with use_cache=False a legacy tuple is left unconverted, and the `get_seq_length()` calls
          # below (and in `_update_causal_mask`) would then fail on it — confirm whether that input is supported.
          return_legacy_cache = False  # noqa: F841
          if use_cache and not isinstance(past_key_values, Cache):
              return_legacy_cache = True  # noqa: F841
              if past_key_values is None:
                  past_key_values = DynamicCache()
              else:
                  past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                  logger.warning_once(
                      "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                      "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                      "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                  )

          # Absolute positions of the new tokens: they continue right after the tokens already held in the cache.
          if cache_position is None:
              past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
              cache_position = torch.arange(
                  past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
              )

          # Default position ids mirror cache_position, with a leading batch dimension of 1.
          if position_ids is None:
              position_ids = cache_position.unsqueeze(0)

          causal_mask = self._update_causal_mask(
              attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
          )

          # No position embedding is added here; `position_ids` are forwarded to every decoder layer instead.
          hidden_states = inputs_embeds

          # normalized
          # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
          # See https://github.com/huggingface/transformers/pull/29402
          # The scale factor is deliberately materialised in the activations' dtype before multiplying.
          normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
          hidden_states = hidden_states * normalizer

          # decoder layers
          all_hidden_states = () if output_hidden_states else None
          all_self_attns = () if output_attentions else None
          next_decoder_cache = None

          for decoder_layer in self.layers:
              # Collect the input to each layer; the final normed output is appended after the loop.
              if output_hidden_states:
                  all_hidden_states += (hidden_states,)

              if self.gradient_checkpointing and self.training:
                  # Arguments are passed positionally here, so their order must match the decoder layer's
                  # `forward` parameter order — TODO confirm against GemmaDecoderLayer.forward if it changes.
                  layer_outputs = self._gradient_checkpointing_func(
                      decoder_layer.__call__,
                      hidden_states,
                      causal_mask,
                      position_ids,
                      past_key_values,
                      output_attentions,
                      use_cache,
                      cache_position,
                  )
              else:
                  layer_outputs = decoder_layer(
                      hidden_states,
                      attention_mask=causal_mask,
                      position_ids=position_ids,
                      past_key_value=past_key_values,
                      output_attentions=output_attentions,
                      use_cache=use_cache,
                      cache_position=cache_position,
                  )

              # A decoder layer returns (hidden_states, [attn_weights if output_attentions], [cache if use_cache]),
              # so the cache index shifts by one when attention weights are present.
              hidden_states = layer_outputs[0]

              if use_cache:
                  next_decoder_cache = layer_outputs[2 if output_attentions else 1]

              if output_attentions:
                  all_self_attns += (layer_outputs[1],)

          hidden_states = self.norm(hidden_states)

          # add hidden states from the last decoder layer
          if output_hidden_states:
              all_hidden_states += (hidden_states,)

          next_cache = next_decoder_cache if use_cache else None
          if return_legacy_cache:
              next_cache = next_cache.to_legacy_cache()

          # Tuple form drops None entries, so element positions depend on which outputs were requested.
          if not return_dict:
              return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
          return BaseModelOutputWithPast(
              last_hidden_state=hidden_states,
              past_key_values=next_cache,
              hidden_states=all_hidden_states,
              attentions=all_self_attns,
          )

      def _update_causal_mask(
          self,
          attention_mask: torch.Tensor,
          input_tensor: torch.Tensor,
          cache_position: torch.Tensor,
          past_key_values: Cache,
          output_attentions: bool,
      ):
          """
          Build the attention mask handed to every decoder layer.

          - flash_attention_2: returns the user's 2D mask unchanged if it contains any 0, otherwise `None`.
          - sdpa (no static cache, no attention output): returns `None` when
            `AttentionMaskConverter._ignore_causal_mask_sdpa` says the causal mask can be skipped.
          - otherwise: a 4D float mask of shape `(batch_size, 1, query_length, target_length)` whose entries are 0
            (attend) or the dtype minimum (masked), built by `_prepare_4d_causal_attention_mask_with_cache_position`.
          """
          if self.config._attn_implementation == "flash_attention_2":
              if attention_mask is not None and 0.0 in attention_mask:
                  return attention_mask
              return None

          # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
          # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
          # to infer the attention mask.
          past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
          using_static_cache = isinstance(past_key_values, StaticCache)

          # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
          if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
              if AttentionMaskConverter._ignore_causal_mask_sdpa(
                  attention_mask,
                  inputs_embeds=input_tensor,
                  past_key_values_length=past_seen_tokens,
                  is_training=self.training,
              ):
                  return None

          dtype, device = input_tensor.dtype, input_tensor.device
          sequence_length = input_tensor.shape[1]
          # Key length covered by the mask: the full static-cache size, else the user mask's width, else
          # cached + new tokens (+1).
          if using_static_cache:
              target_length = past_key_values.get_max_cache_shape()
          else:
              target_length = (
                  attention_mask.shape[-1]
                  if isinstance(attention_mask, torch.Tensor)
                  else past_seen_tokens + sequence_length + 1
              )

          # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
          causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
              attention_mask,
              sequence_length=sequence_length,
              target_length=target_length,
              dtype=dtype,
              device=device,
              cache_position=cache_position,
              batch_size=input_tensor.shape[0],
          )

          if (
              self.config._attn_implementation == "sdpa"
              and attention_mask is not None
              and attention_mask.device.type == "cuda"
              and not output_attentions
          ):
              # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
              # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
              # Details: https://github.com/pytorch/pytorch/issues/110213
              min_dtype = torch.finfo(dtype).min
              causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

          return causal_mask

      @staticmethod
      def _prepare_4d_causal_attention_mask_with_cache_position(
          attention_mask: torch.Tensor,
          sequence_length: int,
          target_length: int,
          dtype: torch.dtype,
          device: torch.device,
          cache_position: torch.Tensor,
          batch_size: int,
          **kwargs,
      ):
          """
          Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
          `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

          Args:
              attention_mask (`torch.Tensor`):
                  A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                  `(batch_size, 1, query_length, key_value_length)`.
              sequence_length (`int`):
                  The sequence length being processed.
              target_length (`int`):
                  The target length: when generating with static cache, the mask should be as long as the static cache,
                  to account for the 0 padding, the part of the cache that is not filled yet.
              dtype (`torch.dtype`):
                  The dtype to use for the 4D attention mask.
              device (`torch.device`):
                  The device to place the 4D attention mask on.
              cache_position (`torch.Tensor`):
                  Indices depicting the position of the input sequence tokens in the sequence.
              batch_size (`int`):
                  Batch size.

          Returns:
              `torch.Tensor`: the 4D mask. When built here, entries are 0 where attention is allowed and
              `torch.finfo(dtype).min` where it is not.
          """
          if attention_mask is not None and attention_mask.dim() == 4:
              # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
              causal_mask = attention_mask
          else:
              min_dtype = torch.finfo(dtype).min
              causal_mask = torch.full(
                  (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
              )
              if sequence_length != 1:
                  causal_mask = torch.triu(causal_mask, diagonal=1)
              # Keep the mask value only for key positions after each query's absolute position (cache_position);
              # earlier keys, including already-cached tokens, become 0 (visible).
              causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
              causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
              if attention_mask is not None:
                  causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                  mask_length = attention_mask.shape[-1]
                  # Assuming a 0/1 padding mask: the sum is 0 only where the key is causally visible (0) AND is a
                  # padding position (attention_mask == 0); those entries are then masked as well.
                  padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                  padding_mask = padding_mask == 0
                  causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                      padding_mask, min_dtype
                  )

          return causal_mask
  class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
      # Parameter names declared to the PreTrainedModel base class as tied weights (presumably tied to the input
      # embeddings by the base class — confirm in `PreTrainedModel.tie_weights`).
      _tied_weights_keys = ["lm_head.weight"]

      def __init__(self, config):
          super().__init__(config)
          self.model = GemmaModel(config)
          self.vocab_size = config.vocab_size
          # Bias-free projection from the final hidden states to vocabulary logits.
          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.model.embed_tokens

      def set_input_embeddings(self, value):
          self.model.embed_tokens = value

      def get_output_embeddings(self):
          return self.lm_head

      def set_output_embeddings(self, new_embeddings):
          self.lm_head = new_embeddings

      def set_decoder(self, decoder):
          self.model = decoder

      def get_decoder(self):
          return self.model

      @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
      @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
      def forward(
          self,
          input_ids: torch.LongTensor = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
          num_logits_to_keep: int = 0,
          **loss_kwargs,
      ) -> Union[Tuple, CausalLMOutputWithPast]:
          r"""
          Args:
              labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

              num_logits_to_keep (`int`, *optional*):
                  Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                  `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                  token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

          Returns:

          Example:

          ```python
          >>> from transformers import AutoTokenizer, GemmaForCausalLM

          >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
          >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

          >>> prompt = "What is your favorite condiment?"
          >>> inputs = tokenizer(prompt, return_tensors="pt")

          >>> # Generate
          >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
          >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
          "What is your favorite condiment?"
          ```"""
          # Fill unset output flags from the config; `use_cache` is resolved by the backbone itself.
          cfg = self.config
          if output_attentions is None:
              output_attentions = cfg.output_attentions
          if output_hidden_states is None:
              output_hidden_states = cfg.output_hidden_states
          if return_dict is None:
              return_dict = cfg.use_return_dict

          # Backbone result: last hidden state first, followed by whichever of cache / hidden states / attentions
          # were produced.
          backbone_outputs = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              cache_position=cache_position,
          )
          final_hidden = backbone_outputs[0]

          # Project only the trailing `num_logits_to_keep` positions; the default 0 makes the slice `-0:` keep every
          # position. Logits stay in the model dtype here (no float upcast).
          logits = self.lm_head(final_hidden[:, -num_logits_to_keep:, :])

          loss = None if labels is None else self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

          if return_dict:
              return CausalLMOutputWithPast(
                  loss=loss,
                  logits=logits,
                  past_key_values=backbone_outputs.past_key_values,
                  hidden_states=backbone_outputs.hidden_states,
                  attentions=backbone_outputs.attentions,
              )

          tuple_output = (logits,) + backbone_outputs[1:]
          return tuple_output if loss is None else (loss,) + tuple_output
  @add_start_docstrings(
      """
      The Gemma Model transformer with a sequence classification head on top (linear layer).

      [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
      (e.g. GPT-2) do.

      Since it does classification on the last token, it requires to know the position of the last token. If a
      `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
      no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
      padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
      each row of the batch).
      """,
      GEMMA_START_DOCSTRING,
  )
  class GemmaForSequenceClassification(GemmaPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels
          self.model = GemmaModel(config)
          # Bias-free per-position classifier; only one position per row is kept in `forward`.
          self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.model.embed_tokens

      def set_input_embeddings(self, value):
          self.model.embed_tokens = value

      @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          transformer_outputs = self.model(
              input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          # Score every position; a single position per row is selected below.
          logits = self.score(transformer_outputs[0])

          batch_source = input_ids if input_ids is not None else inputs_embeds
          batch_size = batch_source.shape[0]

          pad_token_id = self.config.pad_token_id
          if pad_token_id is None:
              # Without a pad token, rows cannot be told apart by length, so only a single row is accepted.
              if batch_size != 1:
                  raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
              last_token_index = -1
          elif input_ids is None:
              # Padding cannot be detected from embeddings: fall back to the final position.
              last_token_index = -1
          else:
              # Index of the first pad token minus one. A row with no pad token gives argmax 0, and the modulo maps
              # the resulting -1 to the last valid index (modulo instead of reverse indexing for ONNX compatibility).
              first_pad = torch.eq(input_ids, pad_token_id).int().argmax(-1)
              last_token_index = ((first_pad - 1) % input_ids.shape[-1]).to(logits.device)

          pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_token_index]

          loss = None
          if labels is not None:
              loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

          if return_dict:
              return SequenceClassifierOutputWithPast(
                  loss=loss,
                  logits=pooled_logits,
                  past_key_values=transformer_outputs.past_key_values,
                  hidden_states=transformer_outputs.hidden_states,
                  attentions=transformer_outputs.attentions,
              )

          tuple_output = (pooled_logits,) + transformer_outputs[1:]
          return tuple_output if loss is None else (loss,) + tuple_output
  @add_start_docstrings(
      """
      The Gemma Model transformer with a token classification head on top (a linear layer on top of the hidden-states
      output) e.g. for Named-Entity-Recognition (NER) tasks.
      """,
      GEMMA_START_DOCSTRING,
  )
  class GemmaForTokenClassification(GemmaPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels
          self.model = GemmaModel(config)
          # Dropout rate: explicit `classifier_dropout` first, then `hidden_dropout`, else 0.1.
          if getattr(config, "classifier_dropout", None) is not None:
              classifier_dropout = config.classifier_dropout
          elif getattr(config, "hidden_dropout", None) is not None:
              classifier_dropout = config.hidden_dropout
          else:
              classifier_dropout = 0.1
          self.dropout = nn.Dropout(classifier_dropout)
          # Per-position classifier over the backbone's last hidden state.
          self.score = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.model.embed_tokens

      def set_input_embeddings(self, value):
          self.model.embed_tokens = value

      @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_DOC,
          output_type=TokenClassifierOutput,
          config_class=_CONFIG_FOR_DOC,
      )
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple, TokenClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the token classification loss, one label per input position. Indices should be in
              `[0, ..., config.num_labels - 1]`.
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          # Always request a ModelOutput from the backbone so the optional fields can be selected by name. The
          # backbone's tuple form skips None entries, so `past_key_values` only sits at index 1 when a cache is
          # returned; slicing `outputs[2:]` silently dropped `hidden_states` whenever caching was off
          # (use_cache=False, or forced off by gradient checkpointing).
          outputs = self.model(
              input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=True,
          )
          sequence_output = outputs.last_hidden_state
          sequence_output = self.dropout(sequence_output)
          logits = self.score(sequence_output)

          loss = None
          if labels is not None:
              loss = self.loss_function(logits, labels, self.config)

          if not return_dict:
              # TokenClassifierOutput has no cache field: keep only the hidden states / attentions that were actually
              # produced, in that order, skipping None like the backbone's own tuple output does.
              output = (logits,) + tuple(
                  value for value in (outputs.hidden_states, outputs.attentions) if value is not None
              )
              return ((loss,) + output) if loss is not None else output

          return TokenClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  # Public classes exported by this module.
  __all__ = [
      "GemmaModel",
      "GemmaForCausalLM",
      "GemmaForSequenceClassification",
      "GemmaForTokenClassification",
  ]