modeling_olmo.py 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142
  1. # coding=utf-8
  2. # Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """PyTorch OLMo model."""
  21. import math
  22. from typing import List, Optional, Tuple, Union
  23. import torch
  24. import torch.nn.functional as F
  25. import torch.utils.checkpoint
  26. from torch import nn
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache, StaticCache
  29. from ...generation import GenerationMixin
  30. from ...modeling_attn_mask_utils import AttentionMaskConverter
  31. from ...modeling_outputs import (
  32. BaseModelOutputWithPast,
  33. CausalLMOutputWithPast,
  34. )
  35. from ...modeling_utils import PreTrainedModel
  36. from ...pytorch_utils import ALL_LAYERNORM_LAYERS
  37. from ...utils import (
  38. add_start_docstrings,
  39. add_start_docstrings_to_model_forward,
  40. is_flash_attn_2_available,
  41. is_flash_attn_greater_or_equal_2_10,
  42. logging,
  43. replace_return_docstrings,
  44. )
  45. from .configuration_olmo import OlmoConfig
# The flash-attention helper is imported only when `is_flash_attn_2_available()` reports support.
if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward

# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# Name of the configuration class; presumably consumed by docstring decorators further down — confirm.
_CONFIG_FOR_DOC = "OlmoConfig"
  50. class OlmoLayerNorm(nn.Module):
  51. """LayerNorm but with no learnable weight or bias."""
  52. def __init__(self, hidden_size: int) -> None:
  53. super().__init__()
  54. self.normalized_shape = (hidden_size,)
  55. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  56. orig_dtype = hidden_states.dtype
  57. return F.layer_norm(hidden_states.to(dtype=torch.float32), self.normalized_shape, None, None, eps=1e-5).to(
  58. orig_dtype
  59. )
# Register the class in the shared ALL_LAYERNORM_LAYERS list so code that consults it also sees OlmoLayerNorm.
ALL_LAYERNORM_LAYERS.append(OlmoLayerNorm)
  61. # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Olmo
  62. # TODO(joao): add me back asap :)
  63. class OlmoRotaryEmbedding(nn.Module):
  64. def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
  65. super().__init__()
  66. self.scaling_factor = scaling_factor
  67. self.dim = dim
  68. self.max_position_embeddings = max_position_embeddings
  69. self.base = base
  70. inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
  71. self.register_buffer("inv_freq", inv_freq, persistent=False)
  72. # For BC we register cos and sin cached
  73. self.max_seq_len_cached = max_position_embeddings
  74. @torch.no_grad()
  75. def forward(self, x, position_ids):
  76. # x: [bs, num_attention_heads, seq_len, head_size]
  77. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  78. position_ids_expanded = position_ids[:, None, :].float()
  79. # Force float32 since bfloat16 loses precision on long contexts
  80. # See https://github.com/huggingface/transformers/pull/29285
  81. device_type = x.device.type
  82. device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
  83. with torch.autocast(device_type=device_type, enabled=False):
  84. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  85. emb = torch.cat((freqs, freqs), dim=-1)
  86. cos = emb.cos()
  87. sin = emb.sin()
  88. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  89. # copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Olmo
  90. # TODO(joao): add me back asap :)
  91. class OlmoLinearScalingRotaryEmbedding(OlmoRotaryEmbedding):
  92. """OlmoRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
  93. def forward(self, x, position_ids):
  94. # difference to the original RoPE: a scaling factor is aplied to the position ids
  95. position_ids = position_ids.float() / self.scaling_factor
  96. cos, sin = super().forward(x, position_ids)
  97. return cos, sin
  98. # copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Olmo
  99. # TODO(joao): add me back asap :)
  100. class OlmoDynamicNTKScalingRotaryEmbedding(OlmoRotaryEmbedding):
  101. """OlmoRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
  102. def forward(self, x, position_ids):
  103. # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
  104. seq_len = torch.max(position_ids) + 1
  105. if seq_len > self.max_position_embeddings:
  106. base = self.base * (
  107. (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
  108. ) ** (self.dim / (self.dim - 2))
  109. inv_freq = 1.0 / (
  110. base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)
  111. )
  112. self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation
  113. cos, sin = super().forward(x, position_ids)
  114. return cos, sin
  115. # Copied from transformers.models.llama.modeling_llama.rotate_half
  116. def rotate_half(x):
  117. """Rotates half the hidden dims of the input."""
  118. x1 = x[..., : x.shape[-1] // 2]
  119. x2 = x[..., x.shape[-1] // 2 :]
  120. return torch.cat((-x2, x1), dim=-1)
  121. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  122. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  123. """Applies Rotary Position Embedding to the query and key tensors.
  124. Args:
  125. q (`torch.Tensor`): The query tensor.
  126. k (`torch.Tensor`): The key tensor.
  127. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  128. sin (`torch.Tensor`): The sine part of the rotary embedding.
  129. position_ids (`torch.Tensor`, *optional*):
  130. Deprecated and unused.
  131. unsqueeze_dim (`int`, *optional*, defaults to 1):
  132. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  133. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  134. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  135. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  136. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  137. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  138. Returns:
  139. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  140. """
  141. cos = cos.unsqueeze(unsqueeze_dim)
  142. sin = sin.unsqueeze(unsqueeze_dim)
  143. q_embed = (q * cos) + (rotate_half(q) * sin)
  144. k_embed = (k * cos) + (rotate_half(k) * sin)
  145. return q_embed, k_embed
  146. class OlmoMLP(nn.Module):
  147. def __init__(self, config):
  148. super().__init__()
  149. self.config = config
  150. self.hidden_size = config.hidden_size
  151. self.intermediate_size = config.intermediate_size
  152. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  153. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  154. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  155. self.act_fn = ACT2FN[config.hidden_act]
  156. def forward(self, x):
  157. return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  158. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  159. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  160. """
  161. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  162. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  163. """
  164. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  165. if n_rep == 1:
  166. return hidden_states
  167. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  168. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  169. class OlmoAttention(nn.Module):
  170. """Multi-headed attention from 'Attention Is All You Need' paper"""
  171. # copied from transformers.models.llama.modeling_llama.LlamaAttention.__init__ with Llama->Olmo
  172. # TODO(joao): add me back asap :)
  173. def __init__(self, config: OlmoConfig, layer_idx: Optional[int] = None):
  174. super().__init__()
  175. self.config = config
  176. self.layer_idx = layer_idx
  177. if layer_idx is None:
  178. logger.warning_once(
  179. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  180. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  181. "when creating this class."
  182. )
  183. self.attention_dropout = config.attention_dropout
  184. self.hidden_size = config.hidden_size
  185. self.num_heads = config.num_attention_heads
  186. self.head_dim = self.hidden_size // self.num_heads
  187. self.num_key_value_heads = config.num_key_value_heads
  188. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  189. self.max_position_embeddings = config.max_position_embeddings
  190. self.rope_theta = config.rope_theta
  191. self.is_causal = True
  192. if (self.head_dim * self.num_heads) != self.hidden_size:
  193. raise ValueError(
  194. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  195. f" and `num_heads`: {self.num_heads})."
  196. )
  197. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
  198. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  199. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
  200. self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
  201. self._init_rope()
  202. def _init_rope(self):
  203. if self.config.rope_scaling is None:
  204. self.rotary_emb = OlmoRotaryEmbedding(
  205. self.head_dim,
  206. max_position_embeddings=self.max_position_embeddings,
  207. base=self.rope_theta,
  208. )
  209. else:
  210. scaling_type = self.config.rope_scaling["type"]
  211. scaling_factor = self.config.rope_scaling["factor"]
  212. if scaling_type == "linear":
  213. self.rotary_emb = OlmoLinearScalingRotaryEmbedding(
  214. self.head_dim,
  215. max_position_embeddings=self.max_position_embeddings,
  216. scaling_factor=scaling_factor,
  217. base=self.rope_theta,
  218. )
  219. elif scaling_type == "dynamic":
  220. self.rotary_emb = OlmoDynamicNTKScalingRotaryEmbedding(
  221. self.head_dim,
  222. max_position_embeddings=self.max_position_embeddings,
  223. scaling_factor=scaling_factor,
  224. base=self.rope_theta,
  225. )
  226. else:
  227. raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
  228. def forward(
  229. self,
  230. hidden_states: torch.Tensor,
  231. attention_mask: Optional[torch.Tensor] = None,
  232. position_ids: Optional[torch.LongTensor] = None,
  233. past_key_value: Optional[Cache] = None,
  234. output_attentions: bool = False,
  235. use_cache: bool = False,
  236. cache_position: Optional[torch.LongTensor] = None,
  237. **kwargs,
  238. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  239. bsz, q_len, _ = hidden_states.size()
  240. query_states = self.q_proj(hidden_states)
  241. key_states = self.k_proj(hidden_states)
  242. value_states = self.v_proj(hidden_states)
  243. if self.config.clip_qkv is not None:
  244. query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  245. key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  246. value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  247. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  248. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  249. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  250. cos, sin = self.rotary_emb(value_states, position_ids)
  251. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  252. if past_key_value is not None:
  253. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  254. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  255. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  256. key_states = repeat_kv(key_states, self.num_key_value_groups)
  257. value_states = repeat_kv(value_states, self.num_key_value_groups)
  258. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  259. if attention_mask is not None: # no matter the length, we just slice it
  260. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  261. attn_weights = attn_weights + causal_mask
  262. # upcast attention to fp32
  263. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
  264. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  265. attn_output = torch.matmul(attn_weights, value_states)
  266. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  267. raise ValueError(
  268. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  269. f" {attn_output.size()}"
  270. )
  271. attn_output = attn_output.transpose(1, 2).contiguous()
  272. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
  273. attn_output = self.o_proj(attn_output)
  274. if not output_attentions:
  275. attn_weights = None
  276. return attn_output, attn_weights, past_key_value
  277. class OlmoFlashAttention2(OlmoAttention):
  278. """
  279. OLMo flash attention module. This module inherits from `OlmoAttention` as the weights of the module stays
  280. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  281. flash attention and deal with padding tokens in case the input contains any of them.
  282. """
  283. # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
  284. def __init__(self, *args, **kwargs):
  285. super().__init__(*args, **kwargs)
  286. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  287. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  288. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  289. self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
  290. def forward(
  291. self,
  292. hidden_states: torch.Tensor,
  293. attention_mask: Optional[torch.LongTensor] = None,
  294. position_ids: Optional[torch.LongTensor] = None,
  295. past_key_value: Optional[Cache] = None,
  296. output_attentions: bool = False,
  297. use_cache: bool = False,
  298. cache_position: Optional[torch.LongTensor] = None,
  299. **kwargs,
  300. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  301. output_attentions = False
  302. bsz, q_len, _ = hidden_states.size()
  303. query_states = self.q_proj(hidden_states)
  304. key_states = self.k_proj(hidden_states)
  305. value_states = self.v_proj(hidden_states)
  306. if self.config.clip_qkv is not None:
  307. query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  308. key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  309. value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  310. # Flash attention requires the input to have the shape
  311. # batch_size x seq_length x head_dim x hidden_dim
  312. # therefore we just need to keep the original shape
  313. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  314. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  315. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  316. cos, sin = self.rotary_emb(value_states, position_ids)
  317. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  318. if past_key_value is not None:
  319. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  320. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  321. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  322. # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
  323. # to be able to avoid many of these transpose/reshape/view.
  324. query_states = query_states.transpose(1, 2)
  325. key_states = key_states.transpose(1, 2)
  326. value_states = value_states.transpose(1, 2)
  327. dropout_rate = self.attention_dropout if self.training else 0.0
  328. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  329. # therefore the input hidden states gets silently casted in float32. Hence, we need
  330. # cast them back in the correct dtype just to be sure everything works as expected.
  331. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  332. # in fp32. (OlmoRMSNorm handles it correctly)
  333. input_dtype = query_states.dtype
  334. if input_dtype == torch.float32:
  335. if torch.is_autocast_enabled():
  336. target_dtype = torch.get_autocast_gpu_dtype()
  337. # Handle the case where the model is quantized
  338. elif hasattr(self.config, "_pre_quantization_dtype"):
  339. target_dtype = self.config._pre_quantization_dtype
  340. else:
  341. target_dtype = self.q_proj.weight.dtype
  342. logger.warning_once(
  343. f"The input hidden states seems to be silently casted in float32, this might be related to"
  344. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  345. f" {target_dtype}."
  346. )
  347. query_states = query_states.to(target_dtype)
  348. key_states = key_states.to(target_dtype)
  349. value_states = value_states.to(target_dtype)
  350. attn_output = _flash_attention_forward(
  351. query_states,
  352. key_states,
  353. value_states,
  354. attention_mask,
  355. q_len,
  356. position_ids=position_ids,
  357. dropout=dropout_rate,
  358. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  359. is_causal=self.is_causal,
  360. )
  361. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
  362. attn_output = self.o_proj(attn_output)
  363. if not output_attentions:
  364. attn_weights = None
  365. return attn_output, attn_weights, past_key_value
  366. class OlmoSdpaAttention(OlmoAttention):
  367. """
  368. OLMo attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  369. `OlmoAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  370. SDPA API.
  371. """
  372. # Adapted from OlmoAttention.forward
  373. def forward(
  374. self,
  375. hidden_states: torch.Tensor,
  376. attention_mask: Optional[torch.Tensor] = None,
  377. position_ids: Optional[torch.LongTensor] = None,
  378. past_key_value: Optional[Cache] = None,
  379. output_attentions: bool = False,
  380. use_cache: bool = False,
  381. cache_position: Optional[torch.LongTensor] = None,
  382. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  383. if output_attentions:
  384. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
  385. logger.warning_once(
  386. "OlmoModel is using OlmoSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
  387. 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  388. )
  389. return super().forward(
  390. hidden_states=hidden_states,
  391. attention_mask=attention_mask,
  392. position_ids=position_ids,
  393. past_key_value=past_key_value,
  394. output_attentions=output_attentions,
  395. use_cache=use_cache,
  396. cache_position=cache_position,
  397. )
  398. bsz, q_len, _ = hidden_states.size()
  399. query_states = self.q_proj(hidden_states)
  400. key_states = self.k_proj(hidden_states)
  401. value_states = self.v_proj(hidden_states)
  402. if self.config.clip_qkv is not None:
  403. query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  404. key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  405. value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
  406. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  407. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  408. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  409. cos, sin = self.rotary_emb(value_states, position_ids)
  410. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  411. if past_key_value is not None:
  412. # sin and cos are specific to RoPE models; cache_position needed for the static cache
  413. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
  414. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  415. key_states = repeat_kv(key_states, self.num_key_value_groups)
  416. value_states = repeat_kv(value_states, self.num_key_value_groups)
  417. causal_mask = attention_mask
  418. # if attention_mask is not None and cache_position is not None:
  419. if attention_mask is not None:
  420. causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
  421. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  422. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  423. if query_states.device.type == "cuda" and causal_mask is not None:
  424. query_states = query_states.contiguous()
  425. key_states = key_states.contiguous()
  426. value_states = value_states.contiguous()
  427. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  428. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  429. is_causal = True if causal_mask is None and q_len > 1 else False
  430. attn_output = torch.nn.functional.scaled_dot_product_attention(
  431. query_states,
  432. key_states,
  433. value_states,
  434. attn_mask=causal_mask,
  435. dropout_p=self.attention_dropout if self.training else 0.0,
  436. is_causal=is_causal,
  437. )
  438. attn_output = attn_output.transpose(1, 2).contiguous()
  439. attn_output = attn_output.view(bsz, q_len, self.hidden_size)
  440. attn_output = self.o_proj(attn_output)
  441. return attn_output, None, past_key_value
# Lookup table from attention-implementation name to the attention class implementing it.
OLMO_ATTENTION_CLASSES = {
    "eager": OlmoAttention,
    "flash_attention_2": OlmoFlashAttention2,
    "sdpa": OlmoSdpaAttention,
}
  447. class OlmoDecoderLayer(nn.Module):
  448. def __init__(self, config: OlmoConfig, layer_idx: int):
  449. super().__init__()
  450. self.hidden_size = config.hidden_size
  451. self.self_attn = OLMO_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
  452. self.mlp = OlmoMLP(config)
  453. self.input_layernorm = OlmoLayerNorm(config.hidden_size)
  454. self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size)
  455. # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer.forward
  456. # TODO(joao): add me back asap :)
  457. def forward(
  458. self,
  459. hidden_states: torch.Tensor,
  460. attention_mask: Optional[torch.Tensor] = None,
  461. position_ids: Optional[torch.LongTensor] = None,
  462. past_key_value: Optional[Cache] = None,
  463. output_attentions: Optional[bool] = False,
  464. use_cache: Optional[bool] = False,
  465. cache_position: Optional[torch.LongTensor] = None,
  466. **kwargs,
  467. ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
  468. """
  469. Args:
  470. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  471. attention_mask (`torch.FloatTensor`, *optional*):
  472. attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
  473. query_sequence_length, key_sequence_length)` if default attention is used.
  474. output_attentions (`bool`, *optional*):
  475. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  476. returned tensors for more detail.
  477. use_cache (`bool`, *optional*):
  478. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  479. (see `past_key_values`).
  480. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
  481. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  482. Indices depicting the position of the input sequence tokens in the sequence
  483. kwargs (`dict`, *optional*):
  484. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  485. into the model
  486. """
  487. residual = hidden_states
  488. hidden_states = self.input_layernorm(hidden_states)
  489. # Self Attention
  490. hidden_states, self_attn_weights, present_key_value = self.self_attn(
  491. hidden_states=hidden_states,
  492. attention_mask=attention_mask,
  493. position_ids=position_ids,
  494. past_key_value=past_key_value,
  495. output_attentions=output_attentions,
  496. use_cache=use_cache,
  497. cache_position=cache_position,
  498. **kwargs,
  499. )
  500. hidden_states = residual + hidden_states
  501. # Fully Connected
  502. residual = hidden_states
  503. hidden_states = self.post_attention_layernorm(hidden_states)
  504. hidden_states = self.mlp(hidden_states)
  505. hidden_states = residual + hidden_states
  506. outputs = (hidden_states,)
  507. if output_attentions:
  508. outputs += (self_attn_weights,)
  509. if use_cache:
  510. outputs += (present_key_value,)
  511. return outputs
# Shared class-level docstring text, passed to `add_start_docstrings` when decorating the model classes.
OLMO_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`OlmoConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
  525. @add_start_docstrings(
  526. "The bare Olmo Model outputting raw hidden-states without any specific head on top.",
  527. OLMO_START_DOCSTRING,
  528. )
  529. # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Olmo
  530. class OlmoPreTrainedModel(PreTrainedModel):
  531. config_class = OlmoConfig
  532. base_model_prefix = "model"
  533. supports_gradient_checkpointing = True
  534. _no_split_modules = ["OlmoDecoderLayer"]
  535. _skip_keys_device_placement = ["past_key_values"]
  536. _supports_flash_attn_2 = True
  537. _supports_sdpa = True
  538. _supports_cache_class = True
  539. _supports_quantized_cache = True
  540. _supports_static_cache = True
  541. def _init_weights(self, module):
  542. std = self.config.initializer_range
  543. if isinstance(module, nn.Linear):
  544. module.weight.data.normal_(mean=0.0, std=std)
  545. if module.bias is not None:
  546. module.bias.data.zero_()
  547. elif isinstance(module, nn.Embedding):
  548. module.weight.data.normal_(mean=0.0, std=std)
  549. if module.padding_idx is not None:
  550. module.weight.data[module.padding_idx].zero_()
  551. OLMO_INPUTS_DOCSTRING = r"""
  552. Args:
  553. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  554. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  555. it.
  556. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  557. [`PreTrainedTokenizer.__call__`] for details.
  558. [What are input IDs?](../glossary#input-ids)
  559. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  560. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  561. - 1 for tokens that are **not masked**,
  562. - 0 for tokens that are **masked**.
  563. [What are attention masks?](../glossary#attention-mask)
  564. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  565. [`PreTrainedTokenizer.__call__`] for details.
  566. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
  567. `past_key_values`).
  568. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
  569. and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
  570. information on the default strategy.
  571. - 1 indicates the head is **not masked**,
  572. - 0 indicates the head is **masked**.
  573. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  574. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  575. config.n_positions - 1]`.
  576. [What are position IDs?](../glossary#position-ids)
  577. past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
  578. Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  579. blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  580. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  581. Two formats are allowed:
  582. - a [`~cache_utils.Cache`] instance, see our
  583. [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
  584. - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  585. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
  586. cache format.
  587. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
  588. legacy cache format will be returned.
  589. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
  590. have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
  591. of shape `(batch_size, sequence_length)`.
  592. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  593. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  594. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  595. model's internal embedding lookup matrix.
  596. use_cache (`bool`, *optional*):
  597. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  598. `past_key_values`).
  599. output_attentions (`bool`, *optional*):
  600. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  601. tensors for more detail.
  602. output_hidden_states (`bool`, *optional*):
  603. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  604. more detail.
  605. return_dict (`bool`, *optional*):
  606. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  607. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  608. Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
  609. this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
  610. the complete sequence length.
  611. """
  612. @add_start_docstrings(
  613. "The bare Olmo Model outputting raw hidden-states without any specific head on top.",
  614. OLMO_START_DOCSTRING,
  615. )
  616. class OlmoModel(OlmoPreTrainedModel):
  617. """
  618. Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OlmoDecoderLayer`]
  619. Args:
  620. config: OlmoConfig
  621. """
  622. def __init__(self, config: OlmoConfig):
  623. super().__init__(config)
  624. self.padding_idx = config.pad_token_id
  625. self.vocab_size = config.vocab_size
  626. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  627. self.layers = nn.ModuleList(
  628. [OlmoDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  629. )
  630. self.norm = OlmoLayerNorm(config.hidden_size)
  631. self.gradient_checkpointing = False
  632. # Initialize weights and apply final processing
  633. self.post_init()
  634. def get_input_embeddings(self):
  635. return self.embed_tokens
  636. def set_input_embeddings(self, value):
  637. self.embed_tokens = value
  638. @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
  639. # copied from transformers.models.llama.modeling_llama.LlamaModel.forward
  640. # TODO(joao): add me back asap :)
  641. def forward(
  642. self,
  643. input_ids: torch.LongTensor = None,
  644. attention_mask: Optional[torch.Tensor] = None,
  645. position_ids: Optional[torch.LongTensor] = None,
  646. past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  647. inputs_embeds: Optional[torch.FloatTensor] = None,
  648. use_cache: Optional[bool] = None,
  649. output_attentions: Optional[bool] = None,
  650. output_hidden_states: Optional[bool] = None,
  651. return_dict: Optional[bool] = None,
  652. cache_position: Optional[torch.LongTensor] = None,
  653. ) -> Union[Tuple, BaseModelOutputWithPast]:
  654. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  655. output_hidden_states = (
  656. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  657. )
  658. use_cache = use_cache if use_cache is not None else self.config.use_cache
  659. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  660. if (input_ids is None) ^ (inputs_embeds is not None):
  661. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  662. if self.gradient_checkpointing and self.training and use_cache:
  663. logger.warning_once(
  664. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
  665. )
  666. use_cache = False
  667. if inputs_embeds is None:
  668. inputs_embeds = self.embed_tokens(input_ids)
  669. # kept for BC (non `Cache` `past_key_values` inputs)
  670. return_legacy_cache = False
  671. if use_cache and not isinstance(past_key_values, Cache):
  672. return_legacy_cache = True
  673. if past_key_values is None:
  674. past_key_values = DynamicCache()
  675. else:
  676. past_key_values = DynamicCache.from_legacy_cache(past_key_values)
  677. logger.warning_once(
  678. "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
  679. "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
  680. "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
  681. )
  682. if cache_position is None:
  683. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  684. cache_position = torch.arange(
  685. past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
  686. )
  687. if position_ids is None:
  688. position_ids = cache_position.unsqueeze(0)
  689. causal_mask = self._update_causal_mask(
  690. attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
  691. )
  692. # embed positions
  693. hidden_states = inputs_embeds
  694. # decoder layers
  695. all_hidden_states = () if output_hidden_states else None
  696. all_self_attns = () if output_attentions else None
  697. next_decoder_cache = None
  698. for decoder_layer in self.layers:
  699. if output_hidden_states:
  700. all_hidden_states += (hidden_states,)
  701. if self.gradient_checkpointing and self.training:
  702. layer_outputs = self._gradient_checkpointing_func(
  703. decoder_layer.__call__,
  704. hidden_states,
  705. causal_mask,
  706. position_ids,
  707. past_key_values,
  708. output_attentions,
  709. use_cache,
  710. cache_position,
  711. )
  712. else:
  713. layer_outputs = decoder_layer(
  714. hidden_states,
  715. attention_mask=causal_mask,
  716. position_ids=position_ids,
  717. past_key_value=past_key_values,
  718. output_attentions=output_attentions,
  719. use_cache=use_cache,
  720. cache_position=cache_position,
  721. )
  722. hidden_states = layer_outputs[0]
  723. if use_cache:
  724. next_decoder_cache = layer_outputs[2 if output_attentions else 1]
  725. if output_attentions:
  726. all_self_attns += (layer_outputs[1],)
  727. hidden_states = self.norm(hidden_states)
  728. # add hidden states from the last decoder layer
  729. if output_hidden_states:
  730. all_hidden_states += (hidden_states,)
  731. next_cache = next_decoder_cache if use_cache else None
  732. if return_legacy_cache:
  733. next_cache = next_cache.to_legacy_cache()
  734. if not return_dict:
  735. return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
  736. return BaseModelOutputWithPast(
  737. last_hidden_state=hidden_states,
  738. past_key_values=next_cache,
  739. hidden_states=all_hidden_states,
  740. attentions=all_self_attns,
  741. )
  742. # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
  743. def _update_causal_mask(
  744. self,
  745. attention_mask: torch.Tensor,
  746. input_tensor: torch.Tensor,
  747. cache_position: torch.Tensor,
  748. past_key_values: Cache,
  749. output_attentions: bool,
  750. ):
  751. if self.config._attn_implementation == "flash_attention_2":
  752. if attention_mask is not None and 0.0 in attention_mask:
  753. return attention_mask
  754. return None
  755. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  756. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  757. # to infer the attention mask.
  758. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  759. using_static_cache = isinstance(past_key_values, StaticCache)
  760. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  761. if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
  762. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  763. attention_mask,
  764. inputs_embeds=input_tensor,
  765. past_key_values_length=past_seen_tokens,
  766. is_training=self.training,
  767. ):
  768. return None
  769. dtype, device = input_tensor.dtype, input_tensor.device
  770. sequence_length = input_tensor.shape[1]
  771. if using_static_cache:
  772. target_length = past_key_values.get_max_cache_shape()
  773. else:
  774. target_length = (
  775. attention_mask.shape[-1]
  776. if isinstance(attention_mask, torch.Tensor)
  777. else past_seen_tokens + sequence_length + 1
  778. )
  779. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  780. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  781. attention_mask,
  782. sequence_length=sequence_length,
  783. target_length=target_length,
  784. dtype=dtype,
  785. device=device,
  786. cache_position=cache_position,
  787. batch_size=input_tensor.shape[0],
  788. )
  789. if (
  790. self.config._attn_implementation == "sdpa"
  791. and attention_mask is not None
  792. and attention_mask.device.type == "cuda"
  793. and not output_attentions
  794. ):
  795. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  796. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  797. # Details: https://github.com/pytorch/pytorch/issues/110213
  798. min_dtype = torch.finfo(dtype).min
  799. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  800. return causal_mask
  801. @staticmethod
  802. # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
  803. def _prepare_4d_causal_attention_mask_with_cache_position(
  804. attention_mask: torch.Tensor,
  805. sequence_length: int,
  806. target_length: int,
  807. dtype: torch.dtype,
  808. device: torch.device,
  809. cache_position: torch.Tensor,
  810. batch_size: int,
  811. **kwargs,
  812. ):
  813. """
  814. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  815. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  816. Args:
  817. attention_mask (`torch.Tensor`):
  818. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  819. `(batch_size, 1, query_length, key_value_length)`.
  820. sequence_length (`int`):
  821. The sequence length being processed.
  822. target_length (`int`):
  823. The target length: when generating with static cache, the mask should be as long as the static cache,
  824. to account for the 0 padding, the part of the cache that is not filled yet.
  825. dtype (`torch.dtype`):
  826. The dtype to use for the 4D attention mask.
  827. device (`torch.device`):
  828. The device to plcae the 4D attention mask on.
  829. cache_position (`torch.Tensor`):
  830. Indices depicting the position of the input sequence tokens in the sequence.
  831. batch_size (`torch.Tensor`):
  832. Batch size.
  833. """
  834. if attention_mask is not None and attention_mask.dim() == 4:
  835. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  836. causal_mask = attention_mask
  837. else:
  838. min_dtype = torch.finfo(dtype).min
  839. causal_mask = torch.full(
  840. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
  841. )
  842. if sequence_length != 1:
  843. causal_mask = torch.triu(causal_mask, diagonal=1)
  844. causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
  845. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  846. if attention_mask is not None:
  847. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  848. mask_length = attention_mask.shape[-1]
  849. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
  850. padding_mask = padding_mask == 0
  851. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  852. padding_mask, min_dtype
  853. )
  854. return causal_mask
  855. # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo
  856. class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin):
  857. _tied_weights_keys = ["lm_head.weight"]
  858. def __init__(self, config):
  859. super().__init__(config)
  860. self.model = OlmoModel(config)
  861. self.vocab_size = config.vocab_size
  862. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  863. # Initialize weights and apply final processing
  864. self.post_init()
  865. def get_input_embeddings(self):
  866. return self.model.embed_tokens
  867. def set_input_embeddings(self, value):
  868. self.model.embed_tokens = value
  869. def get_output_embeddings(self):
  870. return self.lm_head
  871. def set_output_embeddings(self, new_embeddings):
  872. self.lm_head = new_embeddings
  873. def set_decoder(self, decoder):
  874. self.model = decoder
  875. def get_decoder(self):
  876. return self.model
  877. @add_start_docstrings_to_model_forward(OLMO_INPUTS_DOCSTRING)
  878. @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  879. # Ignore copy
  880. def forward(
  881. self,
  882. input_ids: torch.LongTensor = None,
  883. attention_mask: Optional[torch.Tensor] = None,
  884. position_ids: Optional[torch.LongTensor] = None,
  885. past_key_values: Optional[List[torch.FloatTensor]] = None,
  886. inputs_embeds: Optional[torch.FloatTensor] = None,
  887. labels: Optional[torch.LongTensor] = None,
  888. use_cache: Optional[bool] = None,
  889. output_attentions: Optional[bool] = None,
  890. output_hidden_states: Optional[bool] = None,
  891. return_dict: Optional[bool] = None,
  892. cache_position: Optional[torch.LongTensor] = None,
  893. num_logits_to_keep: int = 0,
  894. **loss_kwargs,
  895. ) -> Union[Tuple, CausalLMOutputWithPast]:
  896. r"""
  897. Args:
  898. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  899. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  900. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  901. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  902. num_logits_to_keep (`int`, *optional*):
  903. Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
  904. `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
  905. token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
  906. Returns:
  907. Example:
  908. ```python
  909. >>> from transformers import AutoTokenizer, OlmoForCausalLM
  910. >>> model = OlmoForCausalLM.from_pretrained("allenai/OLMo-1B-hf")
  911. >>> tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-hf")
  912. >>> prompt = "Hey, are you conscious? Can you talk to me?"
  913. >>> inputs = tokenizer(prompt, return_tensors="pt")
  914. >>> # Generate
  915. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  916. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  917. 'Hey, are you conscious? Can you talk to me?\nI’m not sure if you’re conscious of this, but I’m'
  918. ```
  919. """
  920. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  921. output_hidden_states = (
  922. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  923. )
  924. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  925. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  926. outputs = self.model(
  927. input_ids=input_ids,
  928. attention_mask=attention_mask,
  929. position_ids=position_ids,
  930. past_key_values=past_key_values,
  931. inputs_embeds=inputs_embeds,
  932. use_cache=use_cache,
  933. output_attentions=output_attentions,
  934. output_hidden_states=output_hidden_states,
  935. return_dict=return_dict,
  936. cache_position=cache_position,
  937. )
  938. hidden_states = outputs[0]
  939. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  940. logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
  941. loss = None
  942. if labels is not None:
  943. loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
  944. if not return_dict:
  945. output = (logits,) + outputs[1:]
  946. return (loss,) + output if loss is not None else output
  947. return CausalLMOutputWithPast(
  948. loss=loss,
  949. logits=logits,
  950. past_key_values=outputs.past_key_values,
  951. hidden_states=outputs.hidden_states,
  952. attentions=outputs.attentions,
  953. )