modular_glm.py

# coding=utf-8
# Copyright 2024 The GLM & ZhipuAI team and HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.utils.checkpoint

from ...utils import logging
from ..gemma.modeling_gemma import (
    GemmaForCausalLM,
    GemmaForSequenceClassification,
    GemmaForTokenClassification,
)
from ..granite.modeling_granite import (
    GraniteAttention,
    GraniteFlashAttention2,
    GraniteSdpaAttention,
)
from ..llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaModel,
    LlamaPreTrainedModel,
)
from ..phi3.modeling_phi3 import (
    Phi3MLP,
    Phi3RMSNorm,
    Phi3RotaryEmbedding,
)
from .configuration_glm import GlmConfig


logger = logging.get_logger(__name__)
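

# These wrappers reuse the Phi3 implementations unchanged. In the modular
# transformers setup, this file acts as a template that a conversion script
# expands into a standalone modeling_glm.py with the inherited code inlined.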
class GlmRMSNorm(Phi3RMSNorm):
    pass


class GlmRotaryEmbedding(Phi3RotaryEmbedding):
    pass


class GlmMLP(Phi3MLP):
    pass
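

# Unlike the usual Llama-style rotate_half (which swaps the two contiguous
# halves of the last dimension), the GLM variant below rotates interleaved
# pairs: a last dimension [x0, x1, x2, x3] becomes [-x1, x0, -x3, x2].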
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)
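

# Shape sketch (assumed layout): q and k are [batch, num_heads, seq_len, head_dim]
# and cos/sin come from GlmRotaryEmbedding as [batch, seq_len, rotary_dim]. With
# the default unsqueeze_dim=1 they broadcast over the head dimension; only the
# first head_dim // 2 channels of q and k are rotated (in interleaved pairs),
# and the remaining q_pass / k_pass channels are concatenated back unchanged.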
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # Interleave them instead of usual shape
    cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
    sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)

    # Keep half for later concatenation
    q, q_pass = q[..., : q.shape[-1] // 2], q[..., q.shape[-1] // 2 :]
    k, k_pass = k[..., : k.shape[-1] // 2], k[..., k.shape[-1] // 2 :]

    # Apply rotary embeddings on the first half
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    # Concatenate back to full shape
    q_embed = torch.cat([q_embed, q_pass], dim=-1)
    k_embed = torch.cat([k_embed, k_pass], dim=-1)
    return q_embed, k_embed


class GlmAttention(GraniteAttention):
    def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
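        # Override the inherited defaults: GLM uses a bias-free output
        # projection and the standard 1 / sqrt(head_dim) attention scaling.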
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.scaling = 1 / math.sqrt(self.head_dim)


class GlmFlashAttention2(GlmAttention, GraniteFlashAttention2):
    pass


class GlmSdpaAttention(GraniteSdpaAttention):
    pass
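

# The attention backend is selected from this mapping, keyed by the attention
# implementation configured on the model (e.g. "eager", "flash_attention_2" or
# "sdpa" via the config's attention-implementation setting).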
GLM_ATTENTION_CLASSES = {
    "eager": GlmAttention,
    "flash_attention_2": GlmFlashAttention2,
    "sdpa": GlmSdpaAttention,
}


class GlmDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: GlmConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.mlp = GlmMLP(config)
        self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)


class GlmPreTrainedModel(LlamaPreTrainedModel):
    pass


class GlmModel(GlmPreTrainedModel, LlamaModel):
    def __init__(self, config: GlmConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [GlmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
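        # Partial rotary embedding: only head_dim // 2 channels per head carry
        # rotary position information; apply_rotary_pos_emb above rotates that
        # half and leaves the remaining channels untouched.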
        self.rotary_emb = GlmRotaryEmbedding(
            dim=config.head_dim // 2, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta
        )
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()


class GlmForCausalLM(GemmaForCausalLM):
    def __init__(self, config: GlmConfig):
        super().__init__(config)
        self.model = GlmModel(config)
        self.post_init()


class GlmForSequenceClassification(GemmaForSequenceClassification):
    def __init__(self, config: GlmConfig):
        super().__init__(config)
        self.model = GlmModel(config)
        self.post_init()


class GlmForTokenClassification(GemmaForTokenClassification):
    def __init__(self, config: GlmConfig):
        super().__init__(config)
        self.model = GlmModel(config)
        self.post_init()
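

# Minimal usage sketch (illustrative only; the checkpoint name below is an
# assumption, not taken from this file):
#
#     from transformers import AutoModelForCausalLM, AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat")
#     model = AutoModelForCausalLM.from_pretrained("THUDM/glm-4-9b-chat")
#     inputs = tokenizer("Hello", return_tensors="pt")
#     outputs = model.generate(**inputs, max_new_tokens=20)
#     print(tokenizer.decode(outputs[0], skip_special_tokens=True))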


__all__ = [
    "GlmPreTrainedModel",
    "GlmModel",
    "GlmForCausalLM",
    "GlmForSequenceClassification",
    "GlmForTokenClassification",
]