modeling_glm.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/glm/modular_glm.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_glm.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 The GLM & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
  9. #
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. import math
  23. from typing import List, Optional, Tuple, Union
  24. import torch
  25. import torch.nn as nn
  26. import torch.utils.checkpoint
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache, StaticCache
  29. from ...generation import GenerationMixin
  30. from ...modeling_attn_mask_utils import AttentionMaskConverter
  31. from ...modeling_outputs import (
  32. BaseModelOutputWithPast,
  33. CausalLMOutputWithPast,
  34. SequenceClassifierOutputWithPast,
  35. TokenClassifierOutput,
  36. )
  37. from ...modeling_utils import PreTrainedModel
  38. from ...utils import (
  39. add_start_docstrings,
  40. add_start_docstrings_to_model_forward,
  41. is_flash_attn_2_available,
  42. is_flash_attn_greater_or_equal_2_10,
  43. logging,
  44. replace_return_docstrings,
  45. )
  46. from .configuration_glm import GlmConfig
  47. if is_flash_attn_2_available():
  48. from ...modeling_flash_attention_utils import _flash_attention_forward
  49. from ...modeling_flash_attention_utils import _flash_attention_forward
  50. class GlmRMSNorm(nn.Module):
  51. def __init__(self, hidden_size, eps=1e-6):
  52. """
  53. GlmRMSNorm is equivalent to T5LayerNorm
  54. """
  55. super().__init__()
  56. self.weight = nn.Parameter(torch.ones(hidden_size))
  57. self.variance_epsilon = eps
  58. def forward(self, hidden_states):
  59. input_dtype = hidden_states.dtype
  60. hidden_states = hidden_states.to(torch.float32)
  61. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  62. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  63. return self.weight * hidden_states.to(input_dtype)
  64. def extra_repr(self):
  65. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  66. class GlmRotaryEmbedding(nn.Module):
  67. def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
  68. super().__init__()
  69. self.dim = dim
  70. self.max_position_embeddings = max_position_embeddings
  71. self.base = base
  72. inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
  73. self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
  74. @torch.no_grad()
  75. def forward(self, x, position_ids, seq_len=None):
  76. # x: [bs, num_attention_heads, seq_len, head_size]
  77. self.inv_freq.to(x.device)
  78. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  79. position_ids_expanded = position_ids[:, None, :].float()
  80. # Force float32 since bfloat16 loses precision on long contexts
  81. # See https://github.com/huggingface/transformers/pull/29285
  82. device_type = x.device.type
  83. device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
  84. with torch.autocast(device_type=device_type, enabled=False):
  85. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  86. emb = torch.cat((freqs, freqs), dim=-1)
  87. cos = emb.cos()
  88. sin = emb.sin()
  89. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  90. class GlmMLP(nn.Module):
  91. def __init__(self, config):
  92. super().__init__()
  93. self.config = config
  94. self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
  95. self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  96. self.activation_fn = ACT2FN[config.hidden_act]
  97. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  98. up_states = self.gate_up_proj(hidden_states)
  99. gate, up_states = up_states.chunk(2, dim=-1)
  100. up_states = up_states * self.activation_fn(gate)
  101. return self.down_proj(up_states)
# Module-level logger from the library's logging utilities; the attention classes in this file use it to emit
# one-time warnings via `logger.warning_once`.
logger = logging.get_logger(__name__)
  103. def rotate_half(x):
  104. """Rotates half the hidden dims of the input."""
  105. x1 = x[..., 0::2]
  106. x2 = x[..., 1::2]
  107. return torch.stack((-x2, x1), dim=-1).flatten(-2)
  108. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  109. """Applies Rotary Position Embedding to the query and key tensors.
  110. Args:
  111. q (`torch.Tensor`): The query tensor.
  112. k (`torch.Tensor`): The key tensor.
  113. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  114. sin (`torch.Tensor`): The sine part of the rotary embedding.
  115. position_ids (`torch.Tensor`, *optional*):
  116. Deprecated and unused.
  117. unsqueeze_dim (`int`, *optional*, defaults to 1):
  118. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  119. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  120. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  121. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  122. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  123. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  124. Returns:
  125. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  126. """
  127. cos = cos.unsqueeze(unsqueeze_dim)
  128. sin = sin.unsqueeze(unsqueeze_dim)
  129. # Interleave them instead of usual shape
  130. cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
  131. sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
  132. # Keep half for later concatenation
  133. q, q_pass = q[..., : q.shape[-1] // 2], q[..., q.shape[-1] // 2 :]
  134. k, k_pass = k[..., : k.shape[-1] // 2], k[..., k.shape[-1] // 2 :]
  135. # Apply rotary embeddings on the first half
  136. q_embed = (q * cos) + (rotate_half(q) * sin)
  137. k_embed = (k * cos) + (rotate_half(k) * sin)
  138. # Concatenate back to full shape
  139. q_embed = torch.cat([q_embed, q_pass], dim=-1)
  140. k_embed = torch.cat([k_embed, k_pass], dim=-1)
  141. return q_embed, k_embed
  142. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  143. """
  144. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  145. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  146. """
  147. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  148. if n_rep == 1:
  149. return hidden_states
  150. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  151. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class GlmAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Eager (plain PyTorch) implementation with grouped key/value heads (``num_key_value_heads`` may be smaller than
    ``num_attention_heads``) and rotary embeddings applied through `apply_rotary_pos_emb`. `GlmFlashAttention2` and
    `GlmSdpaAttention` subclass this module and reuse its projections, overriding only `forward`.
    """

    def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
        """
        Args:
            config: model config; reads `attention_dropout`, `hidden_size`, `num_attention_heads`,
                `num_key_value_heads` and `attention_bias`.
            layer_idx: index of this layer, passed to `Cache.update` in `forward`. Optional, but a warning is
                logged when it is missing.

        Raises:
            ValueError: if `hidden_size` is not divisible by `num_attention_heads`.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads sharing each key/value head; `repeat_kv` expands K/V by this factor.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.is_causal = True
        # Standard 1/sqrt(head_dim) dot-product scaling.
        self.scaling = 1 / math.sqrt(self.head_dim)

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        # Input width equals num_heads * head_dim because of the divisibility check above.
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager scaled dot-product self-attention.

        Args:
            hidden_states: 3-D input; the first two dims are read as (bsz, q_len). The last dim is presumably
                hidden_size (TODO confirm at call sites).
            attention_mask: optional 4-D additive mask. Its last axis is sliced to the key length and the result is
                added to the raw scores before the softmax.
            position_ids: not used by this method.
            past_key_value: optional `Cache`. When given, the new keys/values are passed to
                `update(..., self.layer_idx, ...)` and the returned (cached) keys/values are attended over.
            output_attentions: when falsy the returned attention weights are None.
            use_cache: not used by this method.
            cache_position: forwarded to the cache update in `cache_kwargs`.
            position_embeddings: `(cos, sin)` tuple. Required in practice despite the `None` default, because it is
                unpacked unconditionally.
            kwargs: accepted and ignored.

        Returns:
            `(attn_output, attn_weights_or_None, past_key_value)`; `attn_output` has shape (bsz, q_len, hidden_size).

        Raises:
            ValueError: if the attention output does not have shape (bsz, num_heads, q_len, head_dim).
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Split heads: (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim).
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand grouped KV heads to one per query head so the matmuls below line up.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back: (bsz, heads, q_len, head_dim) -> (bsz, q_len, heads * head_dim).
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
  227. class GlmFlashAttention2(GlmAttention):
  228. """
  229. Glm flash attention module. This module inherits from `GlmAttention` as the weights of the module stays
  230. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  231. flash attention and deal with padding tokens in case the input contains any of them.
  232. """
  233. def __init__(self, *args, **kwargs):
  234. super().__init__(*args, **kwargs)
  235. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  236. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  237. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  238. self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
  239. def forward(
  240. self,
  241. hidden_states: torch.Tensor,
  242. attention_mask: Optional[torch.LongTensor] = None,
  243. position_ids: Optional[torch.LongTensor] = None,
  244. past_key_value: Optional[Cache] = None,
  245. output_attentions: bool = False,
  246. use_cache: bool = False,
  247. cache_position: Optional[torch.LongTensor] = None,
  248. position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
  249. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  250. output_attentions = False
  251. bsz, q_len, _ = hidden_states.size()
  252. query_states = self.q_proj(hidden_states)
  253. key_states = self.k_proj(hidden_states)
  254. value_states = self.v_proj(hidden_states)
  255. # Flash attention requires the input to have the shape
  256. # batch_size x seq_length x head_dim x hidden_dim
  257. # therefore we just need to keep the original shape
  258. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  259. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  260. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  261. cos, sin = position_embeddings
  262. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  263. if past_key_value is not None:
  264. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  265. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  266. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  267. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  268. # to be able to avoid many of these transpose/reshape/view.
  269. query_states = query_states.transpose(1, 2)
  270. key_states = key_states.transpose(1, 2)
  271. value_states = value_states.transpose(1, 2)
  272. dropout_rate = self.attention_dropout if self.training else 0.0
  273. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  274. # therefore the input hidden states gets silently casted in float32. Hence, we need
  275. # cast them back in the correct dtype just to be sure everything works as expected.
  276. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  277. # in fp32. (GlmRMSNorm handles it correctly)
  278. input_dtype = query_states.dtype
  279. if input_dtype == torch.float32:
  280. if torch.is_autocast_enabled():
  281. target_dtype = torch.get_autocast_gpu_dtype()
  282. # Handle the case where the model is quantized
  283. elif hasattr(self.config, "_pre_quantization_dtype"):
  284. target_dtype = self.config._pre_quantization_dtype
  285. else:
  286. target_dtype = self.q_proj.weight.dtype
  287. logger.warning_once(
  288. f"The input hidden states seems to be silently casted in float32, this might be related to"
  289. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  290. f" {target_dtype}."
  291. )
  292. query_states = query_states.to(target_dtype)
  293. key_states = key_states.to(target_dtype)
  294. value_states = value_states.to(target_dtype)
  295. attn_output = _flash_attention_forward(
  296. query_states,
  297. key_states,
  298. value_states,
  299. attention_mask,
  300. q_len,
  301. position_ids=position_ids,
  302. dropout=dropout_rate,
  303. softmax_scale=self.scaling,
  304. sliding_window=getattr(self, "sliding_window", None),
  305. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  306. is_causal=self.is_causal,
  307. )
  308. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  309. attn_output = self.o_proj(attn_output)
  310. if not output_attentions:
  311. attn_weights = None
  312. return attn_output, attn_weights, past_key_value
class GlmSdpaAttention(GlmAttention):
    """
    Glm attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `GlmAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from GlmAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Self-attention through `torch.nn.functional.scaled_dot_product_attention`.

        Arguments match `GlmAttention.forward`. When `output_attentions` is truthy, a one-time warning is logged and
        the call is delegated to the eager `GlmAttention.forward`, without forwarding `**kwargs`. Otherwise attention
        weights are not returned: the second element of the result is always None on the SDPA path.

        Returns:
            `(attn_output, attn_weights_or_None, past_key_value)` with `attn_output` of shape (bsz, q_len, hidden_size).
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "GlmModel is using GlmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Split heads: (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim).
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand grouped KV heads to one per query head before calling SDPA.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Trim the mask's key axis to the number of keys actually attended (including cached ones).
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # With no mask and a single query token (q_len == 1), causal masking is skipped and the token attends to every key.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
            scale=self.scaling,
        )

        # Merge heads back: (bsz, heads, q_len, head_dim) -> (bsz, q_len, heads * head_dim).
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
# Maps an attention implementation name to the attention class to instantiate; `GlmDecoderLayer` indexes it with
# `config._attn_implementation`.
GLM_ATTENTION_CLASSES = {
    "eager": GlmAttention,
    "flash_attention_2": GlmFlashAttention2,
    "sdpa": GlmSdpaAttention,
}
  393. class GlmDecoderLayer(nn.Module):
  394. def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
  395. super().__init__()
  396. self.hidden_size = config.hidden_size
  397. self.self_attn = GLM_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
  398. self.mlp = GlmMLP(config)
  399. self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  400. self.post_attention_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  401. def forward(
  402. self,
  403. hidden_states: torch.Tensor,
  404. attention_mask: Optional[torch.Tensor] = None,
  405. position_ids: Optional[torch.LongTensor] = None,
  406. past_key_value: Optional[Cache] = None,
  407. output_attentions: Optional[bool] = False,
  408. use_cache: Optional[bool] = False,
  409. cache_position: Optional[torch.LongTensor] = None,
  410. position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
  411. **kwargs,
  412. ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
  413. """
  414. Args:
  415. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  416. attention_mask (`torch.FloatTensor`, *optional*):
  417. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  418. query_sequence_length, key_sequence_length)` if default attention is used.
  419. output_attentions (`bool`, *optional*):
  420. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  421. returned tensors for more detail.
  422. use_cache (`bool`, *optional*):
  423. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  424. (see `past_key_values`).
  425. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
  426. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  427. Indices depicting the position of the input sequence tokens in the sequence
  428. position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
  429. Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
  430. with `head_dim` being the embedding dimension of each attention head.
  431. kwargs (`dict`, *optional*):
  432. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  433. into the model
  434. """
  435. residual = hidden_states
  436. hidden_states = self.input_layernorm(hidden_states)
  437. # Self Attention
  438. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  439. hidden_states=hidden_states,
  440. attention_mask=attention_mask,
  441. position_ids=position_ids,
  442. past_key_value=past_key_value,
  443. output_attentions=output_attentions,
  444. use_cache=use_cache,
  445. cache_position=cache_position,
  446. position_embeddings=position_embeddings,
  447. **kwargs,
  448. )
  449. hidden_states = residual + hidden_states
  450. # Fully Connected
  451. residual = hidden_states
  452. hidden_states = self.post_attention_layernorm(hidden_states)
  453. hidden_states = self.mlp(hidden_states)
  454. hidden_states = residual + hidden_states
  455. outputs = (hidden_states,)
  456. if output_attentions:
  457. outputs += (self_attn_weights,)
  458. if use_cache:
  459. outputs += (present_key_value,)
  460. return outputs
# Introductory docstring shared by GLM model classes; it is passed to `add_start_docstrings` (see the decorator on
# `GlmPreTrainedModel`), which prepends it to the decorated class's documentation.
GLM_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`GlmConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
  474. @add_start_docstrings(
  475. "The bare Glm Model outputting raw hidden-states without any specific head on top.",
  476. GLM_START_DOCSTRING,
  477. )
  478. class GlmPreTrainedModel(PreTrainedModel):
  479. config_class = GlmConfig
  480. base_model_prefix = "model"
  481. supports_gradient_checkpointing = True
  482. _no_split_modules = ["GlmDecoderLayer"]
  483. _skip_keys_device_placement = ["past_key_values"]
  484. _supports_flash_attn_2 = True
  485. _supports_sdpa = True
  486. _supports_cache_class = True
  487. _supports_quantized_cache = True
  488. _supports_static_cache = True
  489. def _init_weights(self, module):
  490. std = self.config.initializer_range
  491. if isinstance(module, nn.Linear):
  492. module.weight.data.normal_(mean=0.0, std=std)
  493. if module.bias is not None:
  494. module.bias.data.zero_()
  495. elif isinstance(module, nn.Embedding):
  496. module.weight.data.normal_(mean=0.0, std=std)
  497. if module.padding_idx is not None:
  498. module.weight.data[module.padding_idx].zero_()
# Config class name as a plain string, presumably consumed by docstring helpers such as `replace_return_docstrings`
# further down the file (usage not visible here; TODO confirm).
_CONFIG_FOR_DOC = "GlmConfig"
  500. GLM_INPUTS_DOCSTRING = r"""
  501. Args:
  502. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  503. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  504. it.
  505. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  506. [`PreTrainedTokenizer.__call__`] for details.
  507. [What are input IDs?](../glossary#input-ids)
  508. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  509. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  510. - 1 for tokens that are **not masked**,
  511. - 0 for tokens that are **masked**.
  512. [What are attention masks?](../glossary#attention-mask)
  513. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  514. [`PreTrainedTokenizer.__call__`] for details.
  515. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
  516. `past_key_values`).
  517. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
  518. and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
  519. information on the default strategy.
  520. - 1 indicates the head is **not masked**,
  521. - 0 indicates the head is **masked**.
  522. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  523. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  524. config.n_positions - 1]`.
  525. [What are position IDs?](../glossary#position-ids)
  526. past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
  527. Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  528. blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  529. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  530. Two formats are allowed:
  531. - a [`~cache_utils.Cache`] instance, see our
  532. [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
  533. - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  534. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
  535. cache format.
  536. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
  537. legacy cache format will be returned.
  538. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
  539. have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
  540. of shape `(batch_size, sequence_length)`.
  541. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  542. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  543. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  544. model's internal embedding lookup matrix.
  545. use_cache (`bool`, *optional*):
  546. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  547. `past_key_values`).
  548. output_attentions (`bool`, *optional*):
  549. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  550. tensors for more detail.
  551. output_hidden_states (`bool`, *optional*):
  552. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  553. more detail.
  554. return_dict (`bool`, *optional*):
  555. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  556. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  557. Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
  558. this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
  559. the complete sequence length.
  560. """
  561. @add_start_docstrings(
  562. "The bare Glm Model outputting raw hidden-states without any specific head on top.",
  563. GLM_START_DOCSTRING,
  564. )
  565. class GlmModel(GlmPreTrainedModel):
  566. """
  567. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GlmDecoderLayer`]
  568. Args:
  569. config: GlmConfig
  570. """
  571. def __init__(self, config: GlmConfig):
  572. super().__init__(config)
  573. self.padding_idx = config.pad_token_id
  574. self.vocab_size = config.vocab_size
  575. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  576. self.layers = nn.ModuleList(
  577. [GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  578. )
  579. self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  580. self.rotary_emb = GlmRotaryEmbedding(
  581. dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
  582. )
  583. self.gradient_checkpointing = False
  584. # Initialize weights and apply final processing
  585. self.post_init()
  586. def get_input_embeddings(self):
  587. return self.embed_tokens
  588. def set_input_embeddings(self, value):
  589. self.embed_tokens = value
  590. @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
  591. def forward(
  592. self,
  593. input_ids: torch.LongTensor = None,
  594. attention_mask: Optional[torch.Tensor] = None,
  595. position_ids: Optional[torch.LongTensor] = None,
  596. past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  597. inputs_embeds: Optional[torch.FloatTensor] = None,
  598. use_cache: Optional[bool] = None,
  599. output_attentions: Optional[bool] = None,
  600. output_hidden_states: Optional[bool] = None,
  601. return_dict: Optional[bool] = None,
  602. cache_position: Optional[torch.LongTensor] = None,
  603. ) -> Union[Tuple, BaseModelOutputWithPast]:
  604. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  605. output_hidden_states = (
  606. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  607. )
  608. use_cache = use_cache if use_cache is not None else self.config.use_cache
  609. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  610. if (input_ids is None) ^ (inputs_embeds is not None):
  611. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  612. if self.gradient_checkpointing and self.training and use_cache:
  613. logger.warning_once(
  614. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  615. )
  616. use_cache = False
  617. if inputs_embeds is None:
  618. inputs_embeds = self.embed_tokens(input_ids)
  619. # kept for BC (non `Cache` `past_key_values` inputs)
  620. return_legacy_cache = False
  621. if use_cache and not isinstance(past_key_values, Cache):
  622. return_legacy_cache = True
  623. if past_key_values is None:
  624. past_key_values = DynamicCache()
  625. else:
  626. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  627. logger.warning_once(
  628. "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
  629. "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
  630. "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
  631. )
  632. if cache_position is None:
  633. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  634. cache_position = torch.arange(
  635. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  636. )
  637. if position_ids is None:
  638. position_ids = cache_position.unsqueeze(0)
  639. causal_mask = self._update_causal_mask(
  640. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  641. )
  642. hidden_states = inputs_embeds
  643. # create position embeddings to be shared across the decoder layers
  644. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  645. # decoder layers
  646. all_hidden_states = () if output_hidden_states else None
  647. all_self_attns = () if output_attentions else None
  648. next_decoder_cache = None
  649. for decoder_layer in self.layers:
  650. if output_hidden_states:
  651. all_hidden_states += (hidden_states,)
  652. if self.gradient_checkpointing and self.training:
  653. layer_outputs = self._gradient_checkpointing_func(
  654. decoder_layer.__call__,
  655. hidden_states,
  656. causal_mask,
  657. position_ids,
  658. past_key_values,
  659. output_attentions,
  660. use_cache,
  661. cache_position,
  662. position_embeddings,
  663. )
  664. else:
  665. layer_outputs = decoder_layer(
  666. hidden_states,
  667. attention_mask=causal_mask,
  668. position_ids=position_ids,
  669. past_key_value=past_key_values,
  670. output_attentions=output_attentions,
  671. use_cache=use_cache,
  672. cache_position=cache_position,
  673. position_embeddings=position_embeddings,
  674. )
  675. hidden_states = layer_outputs[0]
  676. if use_cache:
  677. next_decoder_cache = layer_outputs[2 if output_attentions else 1]
  678. if output_attentions:
  679. all_self_attns += (layer_outputs[1],)
  680. hidden_states = self.norm(hidden_states)
  681. # add hidden states from the last decoder layer
  682. if output_hidden_states:
  683. all_hidden_states += (hidden_states,)
  684. next_cache = next_decoder_cache if use_cache else None
  685. if return_legacy_cache:
  686. next_cache = next_cache.to_legacy_cache()
  687. if not return_dict:
  688. return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
  689. return BaseModelOutputWithPast(
  690. last_hidden_state=hidden_states,
  691. past_key_values=next_cache,
  692. hidden_states=all_hidden_states,
  693. attentions=all_self_attns,
  694. )
  695. def _update_causal_mask(
  696. self,
  697. attention_mask: torch.Tensor,
  698. input_tensor: torch.Tensor,
  699. cache_position: torch.Tensor,
  700. past_key_values: Cache,
  701. output_attentions: bool,
  702. ):
  703. if self.config._attn_implementation == "flash_attention_2":
  704. if attention_mask is not None and 0.0 in attention_mask:
  705. return attention_mask
  706. return None
  707. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  708. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  709. # to infer the attention mask.
  710. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  711. using_static_cache = isinstance(past_key_values, StaticCache)
  712. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  713. if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
  714. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  715. attention_mask,
  716. inputs_embeds=input_tensor,
  717. past_key_values_length=past_seen_tokens,
  718. is_training=self.training,
  719. ):
  720. return None
  721. dtype, device = input_tensor.dtype, input_tensor.device
  722. sequence_length = input_tensor.shape[1]
  723. if using_static_cache:
  724. target_length = past_key_values.get_max_cache_shape()
  725. else:
  726. target_length = (
  727. attention_mask.shape[-1]
  728. if isinstance(attention_mask, torch.Tensor)
  729. else past_seen_tokens + sequence_length + 1
  730. )
  731. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  732. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  733. attention_mask,
  734. sequence_length=sequence_length,
  735. target_length=target_length,
  736. dtype=dtype,
  737. device=device,
  738. cache_position=cache_position,
  739. batch_size=input_tensor.shape[0],
  740. )
  741. if (
  742. self.config._attn_implementation == "sdpa"
  743. and attention_mask is not None
  744. and attention_mask.device.type == "cuda"
  745. and not output_attentions
  746. ):
  747. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  748. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  749. # Details: https://github.com/pytorch/pytorch/issues/110213
  750. min_dtype = torch.finfo(dtype).min
  751. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  752. return causal_mask
  753. @staticmethod
  754. def _prepare_4d_causal_attention_mask_with_cache_position(
  755. attention_mask: torch.Tensor,
  756. sequence_length: int,
  757. target_length: int,
  758. dtype: torch.dtype,
  759. device: torch.device,
  760. cache_position: torch.Tensor,
  761. batch_size: int,
  762. **kwargs,
  763. ):
  764. """
  765. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  766. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  767. Args:
  768. attention_mask (`torch.Tensor`):
  769. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  770. `(batch_size, 1, query_length, key_value_length)`.
  771. sequence_length (`int`):
  772. The sequence length being processed.
  773. target_length (`int`):
  774. The target length: when generating with static cache, the mask should be as long as the static cache,
  775. to account for the 0 padding, the part of the cache that is not filled yet.
  776. dtype (`torch.dtype`):
  777. The dtype to use for the 4D attention mask.
  778. device (`torch.device`):
  779. The device to plcae the 4D attention mask on.
  780. cache_position (`torch.Tensor`):
  781. Indices depicting the position of the input sequence tokens in the sequence.
  782. batch_size (`torch.Tensor`):
  783. Batch size.
  784. """
  785. if attention_mask is not None and attention_mask.dim() == 4:
  786. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  787. causal_mask = attention_mask
  788. else:
  789. min_dtype = torch.finfo(dtype).min
  790. causal_mask = torch.full(
  791. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
  792. )
  793. if sequence_length != 1:
  794. causal_mask = torch.triu(causal_mask, diagonal=1)
  795. causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
  796. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  797. if attention_mask is not None:
  798. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  799. mask_length = attention_mask.shape[-1]
  800. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
  801. padding_mask = padding_mask == 0
  802. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  803. padding_mask, min_dtype
  804. )
  805. return causal_mask
  806. class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
  807. _tied_weights_keys = ["lm_head.weight"]
  808. def __init__(self, config: GlmConfig):
  809. super().__init__(config)
  810. self.model = GlmModel(config)
  811. self.vocab_size = config.vocab_size
  812. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  813. # Initialize weights and apply final processing
  814. self.post_init()
  815. def get_input_embeddings(self):
  816. return self.model.embed_tokens
  817. def set_input_embeddings(self, value):
  818. self.model.embed_tokens = value
  819. def get_output_embeddings(self):
  820. return self.lm_head
  821. def set_output_embeddings(self, new_embeddings):
  822. self.lm_head = new_embeddings
  823. def set_decoder(self, decoder):
  824. self.model = decoder
  825. def get_decoder(self):
  826. return self.model
  827. @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
  828. @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  829. def forward(
  830. self,
  831. input_ids: torch.LongTensor = None,
  832. attention_mask: Optional[torch.Tensor] = None,
  833. position_ids: Optional[torch.LongTensor] = None,
  834. past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  835. inputs_embeds: Optional[torch.FloatTensor] = None,
  836. labels: Optional[torch.LongTensor] = None,
  837. use_cache: Optional[bool] = None,
  838. output_attentions: Optional[bool] = None,
  839. output_hidden_states: Optional[bool] = None,
  840. return_dict: Optional[bool] = None,
  841. cache_position: Optional[torch.LongTensor] = None,
  842. num_logits_to_keep: int = 0,
  843. **loss_kwargs,
  844. ) -> Union[Tuple, CausalLMOutputWithPast]:
  845. r"""
  846. Args:
  847. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  848. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  849. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  850. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  851. num_logits_to_keep (`int`, *optional*):
  852. Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
  853. `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  854. token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
  855. Returns:
  856. Example:
  857. ```python
  858. >>> from transformers import AutoTokenizer, GlmForCausalLM
  859. >>> model = GlmForCausalLM.from_pretrained("google/glm-7b")
  860. >>> tokenizer = AutoTokenizer.from_pretrained("google/glm-7b")
  861. >>> prompt = "What is your favorite condiment?"
  862. >>> inputs = tokenizer(prompt, return_tensors="pt")
  863. >>> # Generate
  864. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  865. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  866. "What is your favorite condiment?"
  867. ```"""
  868. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  869. output_hidden_states = (
  870. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  871. )
  872. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  873. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  874. outputs = self.model(
  875. input_ids=input_ids,
  876. attention_mask=attention_mask,
  877. position_ids=position_ids,
  878. past_key_values=past_key_values,
  879. inputs_embeds=inputs_embeds,
  880. use_cache=use_cache,
  881. output_attentions=output_attentions,
  882. output_hidden_states=output_hidden_states,
  883. return_dict=return_dict,
  884. cache_position=cache_position,
  885. )
  886. hidden_states = outputs[0]
  887. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  888. logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
  889. loss = None
  890. if labels is not None:
  891. loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
  892. if not return_dict:
  893. output = (logits,) + outputs[1:]
  894. return (loss,) + output if loss is not None else output
  895. return CausalLMOutputWithPast(
  896. loss=loss,
  897. logits=logits,
  898. past_key_values=outputs.past_key_values,
  899. hidden_states=outputs.hidden_states,
  900. attentions=outputs.attentions,
  901. )
  902. @add_start_docstrings(
  903. """
  904. The Glm Model transformer with a sequence classification head on top (linear layer).
  905. [`GlmForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  906. (e.g. GPT-2) do.
  907. Since it does classification on the last token, it requires to know the position of the last token. If a
  908. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  909. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  910. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  911. each row of the batch).
  912. """,
  913. GLM_START_DOCSTRING,
  914. )
  915. class GlmForSequenceClassification(GlmPreTrainedModel):
  916. def __init__(self, config: GlmConfig):
  917. super().__init__(config)
  918. self.num_labels = config.num_labels
  919. self.model = GlmModel(config)
  920. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  921. # Initialize weights and apply final processing
  922. self.post_init()
  923. def get_input_embeddings(self):
  924. return self.model.embed_tokens
  925. def set_input_embeddings(self, value):
  926. self.model.embed_tokens = value
  927. @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
  928. def forward(
  929. self,
  930. input_ids: Optional[torch.LongTensor] = None,
  931. attention_mask: Optional[torch.Tensor] = None,
  932. position_ids: Optional[torch.LongTensor] = None,
  933. past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  934. inputs_embeds: Optional[torch.FloatTensor] = None,
  935. labels: Optional[torch.LongTensor] = None,
  936. use_cache: Optional[bool] = None,
  937. output_attentions: Optional[bool] = None,
  938. output_hidden_states: Optional[bool] = None,
  939. return_dict: Optional[bool] = None,
  940. ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
  941. r"""
  942. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  943. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  944. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  945. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  946. """
  947. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  948. transformer_outputs = self.model(
  949. input_ids,
  950. attention_mask=attention_mask,
  951. position_ids=position_ids,
  952. past_key_values=past_key_values,
  953. inputs_embeds=inputs_embeds,
  954. use_cache=use_cache,
  955. output_attentions=output_attentions,
  956. output_hidden_states=output_hidden_states,
  957. return_dict=return_dict,
  958. )
  959. hidden_states = transformer_outputs[0]
  960. logits = self.score(hidden_states)
  961. if input_ids is not None:
  962. batch_size = input_ids.shape[0]
  963. else:
  964. batch_size = inputs_embeds.shape[0]
  965. if self.config.pad_token_id is None and batch_size != 1:
  966. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  967. if self.config.pad_token_id is None:
  968. sequence_lengths = -1
  969. else:
  970. if input_ids is not None:
  971. # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
  972. sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
  973. sequence_lengths = sequence_lengths % input_ids.shape[-1]
  974. sequence_lengths = sequence_lengths.to(logits.device)
  975. else:
  976. sequence_lengths = -1
  977. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
  978. loss = None
  979. if labels is not None:
  980. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  981. if not return_dict:
  982. output = (pooled_logits,) + transformer_outputs[1:]
  983. return ((loss,) + output) if loss is not None else output
  984. return SequenceClassifierOutputWithPast(
  985. loss=loss,
  986. logits=pooled_logits,
  987. past_key_values=transformer_outputs.past_key_values,
  988. hidden_states=transformer_outputs.hidden_states,
  989. attentions=transformer_outputs.attentions,
  990. )
  991. @add_start_docstrings(
  992. """
  993. The Glm Model transformer with a token classification head on top (a linear layer on top of the hidden-states
  994. output) e.g. for Named-Entity-Recognition (NER) tasks.
  995. """,
  996. GLM_START_DOCSTRING,
  997. )
  998. class GlmForTokenClassification(GlmPreTrainedModel):
  999. def __init__(self, config: GlmConfig):
  1000. super().__init__(config)
  1001. self.num_labels = config.num_labels
  1002. self.model = GlmModel(config)
  1003. if getattr(config, "classifier_dropout", None) is not None:
  1004. classifier_dropout = config.classifier_dropout
  1005. elif getattr(config, "hidden_dropout", None) is not None:
  1006. classifier_dropout = config.hidden_dropout
  1007. else:
  1008. classifier_dropout = 0.1
  1009. self.dropout = nn.Dropout(classifier_dropout)
  1010. self.score = nn.Linear(config.hidden_size, config.num_labels)
  1011. # Initialize weights and apply final processing
  1012. self.post_init()
  1013. def get_input_embeddings(self):
  1014. return self.model.embed_tokens
  1015. def set_input_embeddings(self, value):
  1016. self.model.embed_tokens = value
  1017. @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING)
  1018. def forward(
  1019. self,
  1020. input_ids: Optional[torch.LongTensor] = None,
  1021. attention_mask: Optional[torch.Tensor] = None,
  1022. position_ids: Optional[torch.LongTensor] = None,
  1023. past_key_values: Optional[List[torch.FloatTensor]] = None,
  1024. inputs_embeds: Optional[torch.FloatTensor] = None,
  1025. labels: Optional[torch.LongTensor] = None,
  1026. use_cache: Optional[bool] = None,
  1027. output_attentions: Optional[bool] = None,
  1028. output_hidden_states: Optional[bool] = None,
  1029. return_dict: Optional[bool] = None,
  1030. ) -> Union[Tuple, TokenClassifierOutput]:
  1031. r"""
  1032. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1033. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1034. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1035. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1036. """
  1037. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1038. outputs = self.model(
  1039. input_ids,
  1040. attention_mask=attention_mask,
  1041. position_ids=position_ids,
  1042. past_key_values=past_key_values,
  1043. inputs_embeds=inputs_embeds,
  1044. use_cache=use_cache,
  1045. output_attentions=output_attentions,
  1046. output_hidden_states=output_hidden_states,
  1047. return_dict=return_dict,
  1048. )
  1049. sequence_output = outputs[0]
  1050. sequence_output = self.dropout(sequence_output)
  1051. logits = self.score(sequence_output)
  1052. loss = None
  1053. if labels is not None:
  1054. loss = self.loss_function(logits, labels, self.config)
  1055. if not return_dict:
  1056. output = (logits,) + outputs[2:]
  1057. return ((loss,) + output) if loss is not None else output
  1058. return TokenClassifierOutput(
  1059. loss=loss,
  1060. logits=logits,
  1061. hidden_states=outputs.hidden_states,
  1062. attentions=outputs.attentions,
  1063. )
  1064. __all__ = [
  1065. "GlmPreTrainedModel",
  1066. "GlmModel",
  1067. "GlmForCausalLM",
  1068. "GlmForSequenceClassification",
  1069. "GlmForTokenClassification",
  1070. ]