# coding=utf-8
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import sentencepiece as spm
import torch
import torch.utils.checkpoint
from torch import nn

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...configuration_utils import PretrainedConfig
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...tokenization_utils import AddedToken, PreTrainedTokenizer
from ...utils import logging
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaFlashAttention2,
    LlamaForCausalLM,
    LlamaForSequenceClassification,
    LlamaForTokenClassification,
    LlamaModel,
    apply_rotary_pos_emb,
    repeat_kv,
)
from ..llama.tokenization_llama import LlamaTokenizer


if TYPE_CHECKING:
    from ...tokenization_utils_base import TextInput

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

SPIECE_UNDERLINE = "▁"

_CHECKPOINT_FOR_DOC = "google/gemma-7b"

logger = logging.get_logger(__name__)


class GemmaConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Gemma-7B.
    e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.
    Args:
        vocab_size (`int`, *optional*, defaults to 256000):
            Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`GemmaModel`].
        hidden_size (`int`, *optional*, defaults to 3072):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 24576):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 28):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 256):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The legacy activation function. It is overwritten by the `hidden_activation`.
        hidden_activation (`str` or `function`, *optional*):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu_pytorch_tanh"`
            if not specified. `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` activation function.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*, defaults to 0):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 2):
            Beginning of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `True`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    ```python
    >>> from transformers import GemmaModel, GemmaConfig
    >>> # Initializing a Gemma gemma-7b style configuration
    >>> configuration = GemmaConfig()
    >>> # Initializing a model from the gemma-7b style configuration
    >>> model = GemmaModel(configuration)
    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "gemma"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=256000,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        hidden_activation=None,
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.hidden_activation = hidden_activation
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class GemmaTokenizer(LlamaTokenizer, PreTrainedTokenizer):
    """
    Construct a Gemma tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
    no padding token in the original model.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<unk>"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<bos>"`):
            The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
        eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<eos>"`):
            The end of sequence token.
        pad_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `"<pad>"`):
            A special token used to make arrays of tokens the same size for batching purposes. Will then be ignored by
            attention mechanisms or loss computation.
        sp_model_kwargs (`Dict[str, Any]`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
            SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
            to set:

            - `enable_sampling`: Enable subword regularization.
            - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

              - `nbest_size = {0,1}`: No sampling is performed.
              - `nbest_size > 1`: samples from the nbest_size results.
              - `nbest_size < 0`: assuming that nbest_size is infinite and samples from all hypotheses (lattice)
                using forward-filtering-and-backward-sampling algorithm.

            - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
              BPE-dropout.

        add_bos_token (`bool`, *optional*, defaults to `True`):
            Whether or not to add a `bos_token` at the start of sequences.
        add_eos_token (`bool`, *optional*, defaults to `False`):
            Whether or not to add an `eos_token` at the end of sequences.
        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
            Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
            extra spaces.
        use_default_system_prompt (`bool`, *optional*, defaults to `False`):
            Whether or not the default system prompt for Gemma should be used.
        spaces_between_special_tokens (`bool`, *optional*, defaults to `False`):
            Whether or not to add spaces between special tokens.
    """

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<bos>",
        eos_token="<eos>",
        pad_token="<pad>",
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        spaces_between_special_tokens=False,
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token
        eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token
        unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token
        pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.use_default_system_prompt = use_default_system_prompt

        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)

        PreTrainedTokenizer.__init__(
            self,
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            spaces_between_special_tokens=spaces_between_special_tokens,
            **kwargs,
        )

    def get_spm_processor(self):
        raise AttributeError("Not needed for Gemma")

    def unk_token_length(self):
        raise AttributeError("Not needed for Gemma")

    def tokenize(self, text: "TextInput", **kwargs) -> List[str]:
        """
        Args:
            text: TextInput
        Simply calls PreTrainedTokenizer's method
        """
        return PreTrainedTokenizer.tokenize(self, text, **kwargs)

    def _tokenize(self, text, **kwargs):
        """
        Args:
            text: TextInput
        Returns a tokenized string. The Gemma tokenizer never adds a prefix space.
        """
        return self.sp_model.encode(text, out_type=str)

    def _decode(
        self,
        token_ids: List[int],
        skip_special_tokens: bool = False,
        spaces_between_special_tokens: bool = False,
        **kwargs,
    ) -> str:
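        # Decode SentencePiece pieces in runs: added/special tokens are emitted verbatim between runs,
        # so they are never passed through the SentencePiece model itself.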
        sub_texts = []
        current_sub_text = []
        for ids in token_ids:
            if skip_special_tokens and ids in self.all_special_ids:
                continue
            if ids in self._added_tokens_decoder:
                if current_sub_text:
                    sub_texts.append(self.sp_model.decode(current_sub_text))
                sub_texts.append(self._added_tokens_decoder[ids].content)
                current_sub_text = []
            else:
                current_sub_text.append(ids)
        if current_sub_text:
            sub_texts.append(self.sp_model.decode(current_sub_text))

        if spaces_between_special_tokens:
            sub_texts = " ".join(sub_texts)
        else:
            sub_texts = "".join(sub_texts)

        return sub_texts.replace(SPIECE_UNDERLINE, " ")

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) into a single string."""
        current_sub_tokens = []
        out_string = ""
        for token in tokens:
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self._added_tokens_encoder:
                out_string += self.sp_model.decode(current_sub_tokens) + token
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string


class GemmaRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))
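        # Note: the learnable scale is stored as an offset from 1.0 (initialized to zeros) and the
        # forward pass applies `(1 + weight)`, so a freshly initialized layer is a plain RMS norm.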

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm)


class GemmaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
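        # Rotary inverse frequencies: inv_freq[i] = base ** (-2i / dim). Each pair of channels is
        # rotated at a geometrically decreasing frequency controlled by `base` (i.e. `rope_theta`).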
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        self.inv_freq.to(x.device)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


class GemmaLinearScalingRotaryEmbedding(GemmaRotaryEmbedding):
    """GemmaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: a scaling factor is applied to the position ids
        position_ids = position_ids.float() / self.scaling_factor
        cos, sin = super().forward(x, position_ids)
        return cos, sin


class GemmaDynamicNTKScalingRotaryEmbedding(GemmaRotaryEmbedding):
    """GemmaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

    def forward(self, x, position_ids):
        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (
                base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: this may break with compilation

        cos, sin = super().forward(x, position_ids)
        return cos, sin


class GemmaMLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        if config.hidden_activation is None:
            logger.warning_once(
                "`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.\n"
                "Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use\n"
                "`config.hidden_activation` if you want to override this behaviour.\n"
                "See https://github.com/huggingface/transformers/pull/29402 for more details."
            )
            config.hidden_activation = "gelu_pytorch_tanh"
        hidden_activation = config.hidden_activation
        self.act_fn = ACT2FN[hidden_activation]

    def forward(self, x):
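        # GeGLU-style gated MLP: the activated gate projection is multiplied element-wise with the
        # up projection, then projected back down to the hidden size.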
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class GemmaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
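        # Grouped-query attention: each of the `num_key_value_heads` key/value heads is shared by
        # `num_heads // num_key_value_heads` query heads (a single group is plain MHA, `num_heads`
        # groups is MQA).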
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.scaling = 1 / math.sqrt(config.head_dim)

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
        self.rotary_emb = GemmaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

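        # repeat_kv tiles every key/value head `num_key_value_groups` times so that the (possibly
        # smaller) KV projections line up with the query heads for the matmuls below.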
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class GemmaSdpaAttention(GemmaAttention):
    """
    Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `GemmaAttention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
    the SDPA API.
    """

    # Adapted from GemmaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value


class GemmaFlashAttention2(LlamaFlashAttention2, GemmaAttention):
    """
    Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stay
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if isinstance(past_key_value, StaticCache):
            raise ValueError(
                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2`; "
                "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
            )

        output_attentions = False

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim
        # therefore we just need to keep the original shape
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
        # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to the correct dtype just to be sure everything works as expected.
        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
        # to fp32. (GemmaRMSNorm handles it correctly)
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to have been silently cast to float32; this might be related to"
                f" the fact that you have upcasted embedding or layer norm layers in float32. We will cast back the input to"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=getattr(self, "sliding_window", None),
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


GEMMA_ATTENTION_CLASSES = {
    "eager": GemmaAttention,
    "flash_attention_2": GemmaFlashAttention2,
    "sdpa": GemmaSdpaAttention,
}


class GemmaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: GemmaConfig, layer_idx: int):
        super().__init__(config)
        self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.mlp = GemmaMLP(config)
        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
                into the model
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


class GemmaModel(LlamaModel):
    def __init__(self, config: GemmaConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        del self.rotary_emb  # Gemma does not implement rotary emb at the modeling level yet!
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False  # noqa: F841
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True  # noqa: F841
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # embed positions
        hidden_states = inputs_embeds

        # normalized
        # Gemma downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5
        # See https://github.com/huggingface/transformers/pull/29402
        normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
        hidden_states = hidden_states * normalizer

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


# Example where we only modify the docstring and call super
class GemmaForCausalLM(LlamaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = GemmaModel(config)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        ```python
        >>> from transformers import AutoTokenizer, GemmaForCausalLM

        >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
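        # `num_logits_to_keep == 0` slices as `[:, 0:, :]` and therefore keeps logits for every position;
        # callers that only need the logits of the final token can pass `num_logits_to_keep=1`.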
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class GemmaForSequenceClassification(LlamaForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.model = GemmaModel(config)
        self.post_init()


class GemmaForTokenClassification(LlamaForTokenClassification):
    def __init__(self, config):
        super().__init__(config)
        self.model = GemmaModel(config)
        self.post_init()


__all__ = [
    "GemmaConfig",
    "GemmaTokenizer",
    "GemmaModel",
    "GemmaForCausalLM",
    "GemmaForSequenceClassification",
    "GemmaForTokenClassification",
]