# coding=utf-8
# Copyright 2024 state-spaces/mamba2 org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MAMBA2 model."""

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
from .configuration_mamba2 import Mamba2Config


logger = logging.get_logger(__name__)


if is_mamba_2_ssm_available():
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
    from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
else:
    selective_state_update = None

if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))

_CHECKPOINT_FOR_DOC = "mistralai/mamba-codestral-7B-v0.1"
_CONFIG_FOR_DOC = "Mamba2Config"


# Helper methods for segment sum computation


def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
    """
    Pads the input tensor with `pad_size` zeros on the seq_len dim (dim=1).

    Assumes that we only have tensors of either 4 or 3 dimensions.
    """
    pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)

    return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)


def reshape_into_chunks(input_tensor, pad_size, chunk_size):
    """
    Pads `input_tensor` with `pad_size` on the seq_len dim (dim=1) and
    simultaneously splits it into chunk sequences.

    Assumes that we only have tensors of either 4 or 3 dimensions.
    """
    # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
    input_tensor = pad_tensor_by_size(input_tensor, pad_size)

    if len(input_tensor.shape) == 3:
        # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
        return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
    else:
        # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
        return input_tensor.reshape(
            input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
        )

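
# Illustrative shapes for the two helpers above (a sketch, not part of the upstream file): with
# chunk_size=256 and pad_size chosen so that seq_len + pad_size is a multiple of chunk_size,
#
#     pad_tensor_by_size(x, pad_size)        # [bsz, seq_len, h, d] -> [bsz, seq_len + pad_size, h, d]
#     reshape_into_chunks(x, pad_size, 256)  # [bsz, seq_len, h, d] -> [bsz, num_chunks, 256, h, d]
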

def segment_sum(input_tensor):
    """
    More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
    """
    chunk_size = input_tensor.size(-1)
    # 1. expand input tensor to have an additional dimension and repeat along that dimension
    # [..., chunk_size] -> [..., chunk_size, chunk_size]
    input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
    # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
    input_tensor = input_tensor.masked_fill(~mask, 0)
    # 3. compute actual cumsum
    tensor_segsum = torch.cumsum(input_tensor, dim=-2)

    # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
    mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
    tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
    return tensor_segsum

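
# What `segment_sum` computes (a minimal sketch for illustration only): for a 1D input x of length
# `chunk_size`, the output S satisfies S[i, j] = x[j + 1] + ... + x[i] for j <= i (so the diagonal
# is 0) and S[i, j] = -inf above the diagonal. For example:
#
#     x = torch.tensor([1.0, 2.0, 3.0])
#     S = segment_sum(x)
#     # S == [[ 0., -inf, -inf],
#     #       [ 2.,   0., -inf],
#     #       [ 5.,   3.,   0.]]
#
# torch.exp(S) then gives the decay factors used as the causal, mask-like term `L` in the chunked
# scan further below.
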

class Mamba2Cache:
    """
    Arguments:
        config: Mamba2Config
        batch_size: int
        dtype: torch.dtype
        device: torch.device

    Attributes:
        seqlen_offset: int
        dtype: torch.dtype
        conv_states: Dict[int, torch.Tensor]  # layer_idx -> [batch_size, intermediate_size + 2 * n_groups * state_size, conv_kernel_size]
        ssm_states: Dict[int, torch.Tensor]  # layer_idx -> [batch_size, num_heads, head_dim, state_size]
    """

    def __init__(
        self, config: Mamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
    ):
        self.seqlen_offset = 0
        self.dtype = dtype
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = int(config.expand * config.hidden_size)

        self.conv_states = {
            i: torch.zeros(
                batch_size,
                self.intermediate_size + 2 * config.n_groups * config.state_size,
                self.conv_kernel_size,
                device=device,
                dtype=dtype,
            )
            for i in range(config.num_hidden_layers)
        }
        self.ssm_states = {
            i: torch.zeros(
                batch_size, config.num_heads, config.head_dim, config.state_size, device=device, dtype=dtype
            )
            for i in range(config.num_hidden_layers)
        }
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
    ) -> torch.Tensor:
        conv_state = self.conv_states[layer_idx]
        cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

        conv_state = conv_state.roll(shifts=-1, dims=-1)
        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
        self.conv_states[layer_idx].zero_()
        self.conv_states[layer_idx] += conv_state
        return self.conv_states[layer_idx]

    def reset(self):
        # `conv_states` and `ssm_states` are dicts of per-layer tensors, so zero each tensor in place
        for conv_state in self.conv_states.values():
            conv_state.zero_()
        for ssm_state in self.ssm_states.values():
            ssm_state.zero_()

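
# Rough usage sketch (illustration only; `Mamba2Model.forward` builds this cache itself when
# `use_cache=True`):
#
#     cache = Mamba2Cache(config, batch_size=1, dtype=torch.float16, device="cuda")
#     cache.conv_states[0].shape  # (1, intermediate_size + 2 * n_groups * state_size, conv_kernel)
#     cache.ssm_states[0].shape   # (1, num_heads, head_dim, state_size)
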

class MambaRMSNormGated(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states, gate=None):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)

        if gate is not None:
            hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        return self.weight * hidden_states.to(input_dtype)


class Mamba2Mixer(nn.Module):
    """
    Compute ∆, A, B, C, and D, the state space parameters, and compute the `contextualized_states`.
    A and D are input-independent (see the Mamba paper [1], Section 3.5.2, "Interpretation of A", for why A isn't
    selective), while ∆, B, and C are input-dependent. This is a key difference between Mamba and the linear
    time-invariant S4, and is why Mamba is called a **selective** state space model.
    """
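
    # For reference (a reading aid, not upstream code): the discretized recurrence realized by the
    # code paths below is
    #
    #     h_t = exp(∆_t · A) * h_{t-1} + (∆_t · B_t) * x_t
    #     y_t = C_t · h_t + D * x_t
    #
    # where ∆_t (the time step), B_t, and C_t are produced from the input by `in_proj` (hence
    # "selective"), while A and D are learned, input-independent parameters.
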
    def __init__(self, config: Mamba2Config, layer_idx: int):
        super().__init__()
        self.num_heads = config.num_heads
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = int(config.expand * self.hidden_size)
        self.time_step_rank = int(config.time_step_rank)
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        self.layer_norm_epsilon = config.layer_norm_epsilon
        self.rms_norm = config.rms_norm

        self.n_groups = config.n_groups
        self.head_dim = config.head_dim
        self.chunk_size = config.chunk_size

        self.time_step_limit = config.time_step_limit
        self.time_step_min = config.time_step_min
        self.time_step_max = config.time_step_max

        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.conv_dim,
            padding=config.conv_kernel - 1,
        )

        # projection of the input hidden states
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.in_proj = nn.Linear(
            self.hidden_size,
            projection_size,
            bias=config.use_bias,
        )
        # selective projection used to make dt, B and C input dependent

        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_heads))

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
        A = torch.arange(1, self.num_heads + 1)
        self.A_log = nn.Parameter(torch.log(A))
        self.A_log._no_weight_decay = True
        self.norm = MambaRMSNormGated(self.intermediate_size, eps=self.layer_norm_epsilon)
        self.D = nn.Parameter(torch.ones(self.num_heads))
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
                " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size
        d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads

        # getting projected states from cache if it exists
        if cache_params is not None and cache_params.seqlen_offset > 0:
            in_projected_states = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
            d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
            split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads]
            _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1)

            hidden_states_B_C = causal_conv1d_update(
                hidden_states_B_C,
                cache_params.conv_states[self.layer_idx],
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )

            hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                dim=-1,
            )
            A = -torch.exp(self.A_log.float())  # (nheads,)

            A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
            C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
            hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
            hidden_states = selective_state_update(
                cache_params.ssm_states[self.layer_idx],
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
            )
            hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
            hidden_states = self.norm(hidden_states, gate)
            out = self.out_proj(hidden_states)[:, None, ...]

        # if no cache is found, calling the kernel
        else:
            if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
                # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
                dtype = hidden_states.dtype
                hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

            # 1. Gated MLP's linear projection
            projected_states = self.in_proj(hidden_states)
            A = -torch.exp(self.A_log.float())  # (num_heads) or (intermediate_size, state_size)
            dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}

            if self.training and cache_params is None:
                out, ssm_state = mamba_split_conv1d_scan_combined(
                    projected_states,
                    self.conv1d.weight.squeeze(1),
                    self.conv1d.bias,
                    self.dt_bias,
                    A,
                    D=self.D,
                    chunk_size=self.chunk_size,
                    seq_idx=None,  # was seq_idx
                    activation=self.activation,
                    rmsnorm_weight=self.norm.weight,
                    rmsnorm_eps=self.norm.variance_epsilon,
                    outproj_weight=self.out_proj.weight,
                    outproj_bias=self.out_proj.bias,
                    headdim=self.head_dim,
                    ngroups=self.n_groups,
                    norm_before_gate=False,
                    return_final_states=True,
                    **dt_limit_kwargs,
                )

            else:
                gate, hidden_states_B_C, time_step = torch.split(
                    projected_states,
                    [self.intermediate_size, self.conv_dim, self.num_heads],
                    dim=-1,
                )

                # 1D Convolution
                if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
                    hidden_states_B_C = self.act(
                        self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
                    )  # (B, L, self.d_inner + 2 * ngroups * d_state)
                else:
                    hidden_states_B_C = causal_conv1d_fn(
                        x=hidden_states_B_C.transpose(1, 2),
                        weight=self.conv1d.weight.squeeze(1),
                        bias=self.conv1d.bias,
                        activation=self.activation,
                    ).transpose(1, 2)[:, :seq_len]
                hidden_states, B, C = torch.split(
                    hidden_states_B_C,
                    [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                    dim=-1,
                )
                if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
                    # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
                    dtype = hidden_states.dtype
                    hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
                scan_output, ssm_state = mamba_chunk_scan_combined(
                    hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                    time_step,
                    A,
                    B.view(batch_size, seq_len, self.n_groups, -1),
                    C.view(batch_size, seq_len, self.n_groups, -1),
                    chunk_size=self.chunk_size,
                    D=self.D,
                    z=None,
                    seq_idx=None,
                    return_final_states=True,
                    dt_bias=self.dt_bias,
                    dt_softplus=True,
                    **dt_limit_kwargs,
                )
                if ssm_state is not None and cache_params is not None:
                    cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
                scan_output = scan_output.view(batch_size, seq_len, -1)
                # Multiply "gate" branch and apply extra normalization layer
                scan_output = self.norm(scan_output, gate)
                out = self.out_proj(scan_output)
        return out

    # fmt: off
    def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype

        # Gated MLP's linear projection
        projected_states = self.in_proj(input_states.squeeze(1))
        d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
        _, _, gate, hidden_states, dt = projected_states.split(
            [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
        )

        # Convolution sequence transformation
        if cache_params is not None:
            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
            ssm_state = ssm_state.to(hidden_states.device)
            if cache_params.seqlen_offset > 0:
                conv_state = cache_params.conv_states[self.layer_idx]  # [batch, intermediate_size, conv_kernel_size]
                conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
                # handle batched generation - states are copied through
                conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
                cache_params.conv_states[self.layer_idx].copy_(conv_state)
                hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1)
                if self.use_conv_bias:
                    hidden_states += self.conv1d.bias
                hidden_states = self.act(hidden_states).to(dtype)[:, None, ...]  # [batch, 1, intermediate_size] : decoding
            else:
                hidden_states = hidden_states.transpose(1, 2)
                conv_state = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
                )
                cache_params.conv_states[self.layer_idx].copy_(conv_state)
                hidden_states = self.act(self.conv1d(hidden_states).transpose(1, 2))[:, :seq_len, :]  # [batch, intermediate_size, seq_len]
                if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
                    dtype = hidden_states.dtype
                    # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
                    hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
        else:
            ssm_state = torch.zeros(
                (batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
                device=hidden_states.device, dtype=dtype
            )
            hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
        hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
        A = -torch.exp(self.A_log.float())  # [num_heads]
        if cache_params is not None and cache_params.seqlen_offset > 0:
            # Note: there is no need to pad parameter matrices here, as there is just one new token
            # for batched generation
            dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
            dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
            # [num_heads] -> [num_heads, head_dim]
            dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)

            dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
            dt = torch.clamp(dt, self.time_step_min)  # , self.time_step_max)
            A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            # [bsz, num_heads, head_dim, state_size]
            dA = torch.exp(dt[..., None] * A)

            # Discretize B
            # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
            # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
            B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
            B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
            B = B.reshape(batch_size, -1, B.shape[-1])
            # [bsz, num_heads, head_dim, state_size]
            dB = dt[..., None] * B[..., None, :]

            # Discretize x into dB
            # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
            hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
            dBx = dB * hidden_states[..., None]

            # State calculation
            cache_params.ssm_states[self.layer_idx].copy_(
                cache_params.ssm_states[self.layer_idx] * dA + dBx
            )

            # Subsequent output
            # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
            C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
            C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
            C = C.reshape(batch_size, -1, C.shape[-1])
            # [bsz, num_heads, head_dim]

            ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype)  # Shape: [b, h, d, n]
            # Reshape ssm_states to merge the first two dimensions
            ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size)  # Shape: [b*h, d, n]
            C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1)  # Shape: [b*h, n, 1]
            y = torch.bmm(ssm_states_reshaped, C_reshaped)
            y = y.view(batch_size, self.num_heads, self.head_dim)

            # D skip connection
            # [num_heads] -> [num_heads, head_dim]
            D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
            y = (y + hidden_states * D).to(y.dtype)

            # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
            y = y.reshape(batch_size, -1)[:, None, ...]
        else:
            # begin ssd naive implementation without einsums
            dt = nn.functional.softplus(dt + self.dt_bias)
            dt = torch.clamp(dt, self.time_step_min)
            hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
            B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
            B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
            C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
            pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

            D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

            # Discretize x and A
            hidden_states = hidden_states * dt[..., None]
            A = A.to(hidden_states.dtype) * dt

            # Rearrange into blocks/chunks
            hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]

            # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
            A = A.permute(0, 3, 1, 2)
            A_cumsum = torch.cumsum(A, dim=-1)

            # 1. Compute the output for each intra-chunk (diagonal blocks)
            # This is the analog of a causal mask
            L = torch.exp(segment_sum(A))

            # First, contraction of C and B to get G (attention-weights like)
            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]  # shape: (b, c, l, s, h, n)
            G = G_intermediate.sum(dim=-1)  # shape: (b, c, l, s, h)

            # Step 2: Compute M, equivalent to applying attention mask to weights
            M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
            M = M_intermediate.sum(dim=-1)

            # Step 3: Compute Y_diag (apply to values)
            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)

            # (right term of low-rank factorization of off-diagonal blocks; B terms)
            decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
            B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
            # permute back B * decay states
            states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
            if cache_params is not None and cache_params.seqlen_offset > 0:
                previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
            else:
                previous_states = torch.zeros_like(states[:, :1])
            states = torch.cat([previous_states, states], dim=1)
            decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))

            states_permuted = states.permute(0, 2, 1, 3, 4)
            result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
            new_states = result.permute(0, 2, 1, 3, 4)
            states, ssm_state = new_states[:, :-1], new_states[:, -1]

            # Compute state -> output conversion per chunk
            # (left term of low-rank factorization of off-diagonal blocks; C terms)
            state_decay_out = torch.exp(A_cumsum)
            # compute Yoff
            C_times_states = (C[..., None, :] * states[:, :, None, ...])
            state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
            Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])

            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
            y = Y_diag + Y_off
            # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
            y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)

            y = y + D_residual
            # Cutting off padded chunks
            if pad_size > 0:
                y = y[:, :seq_len, :, :]
            y = y.reshape(batch_size, seq_len, -1)
            if ssm_state is not None and cache_params is not None:
                cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

        scan_output = self.norm(y, gate)

        # end ssd naive

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.to(dtype))  # [batch, seq_len, hidden_size]
        return contextualized_states
    # fmt: on

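    # Reading aid for the chunked ("SSD") branch of `torch_forward` above (a summary, not upstream
    # code): the sequence is split into chunks of `chunk_size`. Within a chunk, the output is the
    # causal, attention-like term Y_diag built from C·Bᵀ weighted by the decay matrix
    # L = exp(segment_sum(A)). Information crossing chunk boundaries is carried by the running
    # state, giving the off-diagonal term Y_off = C·states scaled by exp(A_cumsum). The branch
    # returns y = Y_diag + Y_off + D_residual, matching the per-step recurrence noted after the
    # class docstring.
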
    def forward(
        self,
        hidden_states,
        cache_params: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
        dtype = hidden_states.dtype
        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

        return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)


class Mamba2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Mamba2RMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class Mamba2Block(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = Mamba2Mixer(config, layer_idx=layer_idx)

    def forward(
        self,
        hidden_states,
        cache_params: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        residual = hidden_states
        hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)

        hidden_states = self.mixer(
            hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
        )
        hidden_states = residual + hidden_states
        return hidden_states


class Mamba2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Mamba2Config
    base_model_prefix = "backbone"
    _no_split_modules = ["Mamba2Block"]
    supports_gradient_checkpointing = True
    _is_stateful = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, Mamba2Mixer):
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            dt = torch.exp(
                torch.rand(self.config.num_heads)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)

            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                module.dt_bias.copy_(inv_dt)
            module.dt_bias._no_reinit = True

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.initializer_range)

        if self.config.rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(self.config.num_hidden_layers)


@dataclass
# Copied from transformers.models.mamba.modeling_mamba.MambaOutput with MAMBA->MAMBA2,Mamba->Mamba2
class Mamba2Output(ModelOutput):
    """
    Class for the MAMBA2 model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`Mamba2Cache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    cache_params: Optional[Mamba2Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
# Copied from transformers.models.mamba.modeling_mamba.MambaCausalLMOutput with Mamba->Mamba2
class Mamba2CausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`Mamba2Cache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[Mamba2Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


MAMBA2_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`Mamba2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

MAMBA2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            Indices of input sequence tokens in the vocabulary.

            If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        cache_params (`Mamba2Cache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model had `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The bare MAMBA2 Model transformer outputting raw hidden-states without any specific head on top.",
    MAMBA2_START_DOCSTRING,
)
class Mamba2Model(Mamba2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Mamba2Block(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        self.norm_f = Mamba2RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Initialize weights and apply final processing
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.post_init()

    def load_hook(self, state_dict, prefix, *args):
        for k in state_dict:
            if "embedding." in k:
                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
                break

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Mamba2Output,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        cache_params: Optional[Mamba2Cache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, Mamba2Output]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):  # ^ is python for xor
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if use_cache:
            if cache_params is None:
                cache_params = Mamba2Cache(
                    self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
                )
                cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
            elif cache_position is None:
                # cases when we do a manual forward instead of using `model.generate`, which initializes
                # `cache_position` and makes sure it is not None; throw an error here instead of doing
                # some hack to conjecture the current cache position
                raise ValueError(
                    "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
                    "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
                    "be initialized for you automatically"
                )
        else:
            cache_params = None

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        for mixer_block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
                )
            else:
                hidden_states = mixer_block(
                    hidden_states,
                    cache_params=cache_params,
                    cache_position=cache_position,
                    attention_mask=attention_mask,
                )

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        if use_cache:
            cache_params.seqlen_offset += inputs_embeds.shape[1]

        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

        return Mamba2Output(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
        )


@add_start_docstrings(
    """
    The MAMBA2 Model transformer with a language modeling head on top (linear layer with weights not tied to the input
    embeddings).
    """,
    MAMBA2_START_DOCSTRING,
)
class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin):
    _tied_weights_keys = []

    def __init__(self, config):
        super().__init__(config)
        self.backbone = Mamba2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        inputs_embeds=None,
        use_cache=None,
        cache_params: Optional[Mamba2Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        # Overwritten -- uses `cache_params` as opposed to `past_key_values`

        if inputs_embeds is not None:
            past_len = inputs_embeds.shape[1] + input_ids.shape[1]
        else:
            past_len = input_ids.shape[1]
        if use_cache:
            # `cache_position` should have been initialized in `generate`
            if cache_position is None:
                raise ValueError(
                    "`cache_position` should not be None as it should have been initialized in "
                    "`model.generate`, you are responsible for passing in a valid `cache_position` if "
                    "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
                )
            # how do we detect that we are in decoding without cache?
            if cache_position[0] > 0:
                input_ids = input_ids[:, -1][..., None]
                attention_mask = attention_mask[:, -1][..., None]
            else:
                # we initialize the `cache_position` to full size of `conv_states` at prefill stage
                # considering padding will be applied when input length is shorter, and truncation
                # will be applied when it is longer, so it will be equivalent to always have it match
                # the length of `cache_params.conv_states`, which is `config.conv_kernel`
                cache_position = torch.arange(0, past_len, device=input_ids.device)
                # if the cache is not used, we also do have to extend the attention mask here
                # TODO there is likely a cleverer way to do this
                extended_mask = torch.ones(
                    attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device
                )
                attention_mask = torch.cat([attention_mask, extended_mask], dim=1)
                cache_params = None

        if attention_mask.shape[1] < past_len:
            # we have to update manually the attention mask if
            # we are in decoding without cache
            # and we don't have position_ids here
            # TODO but we should be able to use cache_position though at a later time
            extended_mask = torch.ones(
                attention_mask.size(0), past_len - attention_mask.shape[1], device=attention_mask.device
            )
            attention_mask = torch.cat([attention_mask, extended_mask], dim=1)
        if inputs_embeds is not None and cache_params is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "attention_mask": attention_mask,
                "cache_params": cache_params,
                "use_cache": use_cache,
                "cache_position": cache_position,
            }
        )
        return model_inputs

    @add_start_docstrings_to_model_forward(MAMBA2_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=Mamba2CausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[Mamba2Cache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,  # for now we need this for generation
    ) -> Union[Tuple, Mamba2CausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        mamba2_outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
            cache_position=cache_position,
            attention_mask=attention_mask,
        )
        hidden_states = mamba2_outputs[0]

        logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + mamba2_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Mamba2CausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=mamba2_outputs.cache_params,
            hidden_states=mamba2_outputs.hidden_states,
        )
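

# Minimal usage sketch (illustration only, not part of the upstream module), assuming the checkpoint
# referenced by `_CHECKPOINT_FOR_DOC` above and the standard `transformers` loading API:
#
#     from transformers import AutoTokenizer, Mamba2ForCausalLM
#
#     tokenizer = AutoTokenizer.from_pretrained("mistralai/mamba-codestral-7B-v0.1")
#     model = Mamba2ForCausalLM.from_pretrained("mistralai/mamba-codestral-7B-v0.1")
#     inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
#     outputs = model.generate(**inputs, max_new_tokens=32)
#     print(tokenizer.decode(outputs[0], skip_special_tokens=True))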