# coding=utf-8
# Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch MAMBA model."""

import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import MambaCache
from ...generation import GenerationMixin
from ...modeling_utils import PreTrainedModel
from ...utils import (
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
)
from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
from .configuration_mamba import MambaConfig


logger = logging.get_logger(__name__)

if is_mambapy_available():
    from mambapy.pscan import pscan
else:
    pscan = None

if is_mamba_ssm_available():
    from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
else:
    selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

if is_causal_conv1d_available():
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
else:
    causal_conv1d_update, causal_conv1d_fn = None, None

is_fast_path_available = all(
    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
)

_CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m-hf"
_CONFIG_FOR_DOC = "MambaConfig"
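
# Three execution paths are implemented below:
#   1. fused CUDA kernels from `mamba_ssm` and `causal_conv1d`, used when every optional import
#      above is available and the weights sit on a CUDA device,
#   2. the `mambapy` parallel scan (`pscan`), used as a training-time fallback when
#      `config.use_mambapy` is set,
#   3. a pure PyTorch sequential scan in `MambaMixer.slow_forward`, which always works.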


class MambaMixer(nn.Module):
    """
    Compute ∆, A, B, C, and D, the state space parameters, and compute the `contextualized_states`.
    A and D are input-independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective),
    while ∆, B and C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called a **selective** state space model).
    """

    def __init__(self, config: MambaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.intermediate_size = config.intermediate_size
        self.time_step_rank = int(config.time_step_rank)
        self.layer_idx = layer_idx
        self.use_conv_bias = config.use_conv_bias
        self.conv1d = nn.Conv1d(
            in_channels=self.intermediate_size,
            out_channels=self.intermediate_size,
            bias=config.use_conv_bias,
            kernel_size=config.conv_kernel,
            groups=self.intermediate_size,
            padding=config.conv_kernel - 1,
        )

        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]

        self.use_mambapy = config.use_mambapy

        # projection of the input hidden states
        self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias)
        # selective projection used to make dt, B and C input dependent
        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
        # time step projection (discretization)
        self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True)

        # S4D real initialization. These are not discretized!
        # The core is to load them, compute the discrete states, then write the updated state. This keeps the memory bounded.
        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
        A = A.expand(self.intermediate_size, -1).contiguous()

        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(self.intermediate_size))
        self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
        self.use_bias = config.use_bias

        if not is_fast_path_available:
            if self.use_mambapy:
                if is_mambapy_available():
                    logger.warning_once(
                        "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                        " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
                        " https://github.com/Dao-AILab/causal-conv1d"
                    )
                else:
                    raise ImportError(
                        "use_mambapy is set to True but the mambapy package is not installed. To install it follow https://github.com/alxndrTL/mamba.py."
                    )
            else:
                logger.warning_once(
                    "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
                    " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
                    " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
                )
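
    # Projection shapes (per token): `in_proj` maps hidden_size -> 2 * intermediate_size (hidden
    # states and gate), `x_proj` maps intermediate_size -> time_step_rank + 2 * ssm_state_size
    # (Δ, B, C), `dt_proj` maps time_step_rank -> intermediate_size, and `out_proj` maps
    # intermediate_size back to hidden_size.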

    def cuda_kernels_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states).transpose(1, 2)

        if self.training and cache_params is None:  # Doesn't support outputting the states -> used for training
            contextualized_states = mamba_inner_fn(
                projected_states,
                self.conv1d.weight,
                self.conv1d.bias if self.use_conv_bias else None,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias.float() if self.use_bias else None,
                -torch.exp(self.A_log.float()),
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            hidden_states, gate = projected_states.chunk(2, dim=1)

            if attention_mask is not None:
                hidden_states = hidden_states * attention_mask.unsqueeze(1)

            # 2. Convolution sequence transformation
            conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
            if cache_params is not None and cache_position[0] > 0:
                hidden_states = causal_conv1d_update(
                    hidden_states.squeeze(-1),
                    cache_params.conv_states[self.layer_idx],
                    conv_weights,
                    self.conv1d.bias,
                    self.activation,
                )
                hidden_states = hidden_states.unsqueeze(-1)
            else:
                if cache_params is not None:
                    conv_states = nn.functional.pad(
                        hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
                    )
                    cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
                hidden_states = causal_conv1d_fn(
                    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
                )

            if attention_mask is not None:
                hidden_states = hidden_states * attention_mask.unsqueeze(1)

            # 3. State Space Model sequence transformation
            # 3.a. input varying initialization of time_step, B and C
            ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
            time_step, B, C = torch.split(
                ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
            )
            discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)

            A = -torch.exp(self.A_log.float())
            # 3.c perform the recurrence y ← SSM(A, B, C)(x)
            time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None
            if cache_params is not None and cache_position[0] > 0:
                scan_outputs = selective_state_update(
                    cache_params.ssm_states[self.layer_idx],
                    hidden_states[..., 0],
                    discrete_time_step[..., 0],
                    A,
                    B[:, 0],
                    C[:, 0],
                    self.D,
                    gate[..., 0],
                    time_proj_bias,
                    dt_softplus=True,
                ).unsqueeze(-1)
            else:
                scan_outputs, ssm_state = selective_scan_fn(
                    hidden_states,
                    discrete_time_step,
                    A,
                    B.transpose(1, 2),
                    C.transpose(1, 2),
                    self.D.float(),
                    gate,
                    time_proj_bias,
                    delta_softplus=True,
                    return_last_state=True,
                )
                if ssm_state is not None and cache_params is not None:
                    cache_params.update_ssm_state(self.layer_idx, ssm_state)

            # 4. Final linear projection
            contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
        return contextualized_states
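
    # `slow_forward` below is the pure PyTorch reference path: with `use_mambapy` during training
    # it uses the mamba.py parallel scan (`pscan`); otherwise it unrolls the recurrence with an
    # explicit loop over `seq_len`, which is also the path used for cached decoding when the fused
    # kernels are unavailable.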
    # fmt: off
    def slow_forward(self, input_states, cache_params: Optional[MambaCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.LongTensor] = None):
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(input_states).transpose(1, 2)                   # [batch, 2 * intermediate_size, seq_len]
        hidden_states, gate = projected_states.chunk(2, dim=1)

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 2. Convolution sequence transformation
        if cache_params is not None:
            ssm_state = cache_params.ssm_states[self.layer_idx].clone()
            ssm_state = ssm_state.to(hidden_states.device)
            # use `cache_position.shape[0]` to check whether we are in prefill
            # stage, it's equivalent to check `cache_position[0] == 0`, which
            # breaks dynamo fullgraph constraints
            if cache_position.shape[0] == self.conv_kernel_size:
                conv_state = nn.functional.pad(
                    hidden_states,
                    (self.conv_kernel_size - hidden_states.shape[-1], 0)
                )
                cache_params.update_conv_state(self.layer_idx, conv_state, cache_position)
                hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])      # [batch, intermediate_size, seq_len]
            else:
                conv_state = cache_params.update_conv_state(self.layer_idx, hidden_states, cache_position)
                hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1)
                if self.use_conv_bias:
                    hidden_states += self.conv1d.bias
                hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)          # [batch, intermediate_size, 1] : decoding
        else:
            ssm_state = torch.zeros(
                (batch_size, self.intermediate_size, self.ssm_state_size),
                device=hidden_states.device, dtype=dtype
            )
            hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])          # [batch, intermediate_size, seq_len]

        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(1)

        # 3. State Space Model sequence transformation
        # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
        time_step, B, C = torch.split(
            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
        )
        discrete_time_step = self.dt_proj(time_step)                                     # [batch, seq_len, intermediate_size]
        discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2)  # [batch, intermediate_size, seq_len]

        # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM)
        A = -torch.exp(self.A_log.float())                                               # [intermediate_size, ssm_state_size]
        discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None])  # [batch, intermediate_size, seq_len, ssm_state_size]
        discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float()        # [batch, intermediate_size, seq_len, ssm_state_size]
        deltaB_u = discrete_B * hidden_states[:, :, :, None].float()

        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
        if self.use_mambapy and self.training and cache_params is None:
            hs = pscan(discrete_A.transpose(1, 2), deltaB_u.transpose(1, 2))             # [batch, seq_len, intermediate_size, ssm_state_size]

            scan_output = (hs @ C.unsqueeze(-1)).squeeze(3).transpose(1, 2)              # [batch, intermediate_size, seq_len]
            scan_output = scan_output + hidden_states * self.D[None, :, None]
            scan_output = scan_output * self.act(gate)
        else:
            scan_outputs = []
            for i in range(seq_len):
                ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediate_size, ssm_state_size]
                scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediate_size, 1]
                scan_outputs.append(scan_output[:, :, 0])
            scan_output = torch.stack(scan_outputs, dim=-1)                                # [batch, intermediate_size, seq_len]
            scan_output = scan_output + (hidden_states * self.D[None, :, None])
            scan_output = (scan_output * self.act(gate))

            if cache_params is not None:
                cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

        # 4. Final linear projection
        contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
        return contextualized_states
    # fmt: on
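
    # `forward` dispatches to the fused CUDA kernels only when all optional kernels are available,
    # the weights sit on a CUDA device, and the model is not being traced by torch.compile;
    # otherwise it falls back to `slow_forward`.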

    def forward(
        self,
        hidden_states,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not torch._dynamo.is_compiling():
            return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
        return self.slow_forward(hidden_states, cache_params, cache_position, attention_mask)


class MambaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        MambaRMSNorm is equivalent to T5LayerNorm and LlamaRMSNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
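
    # RMS normalization: y = weight * x / sqrt(mean(x**2, dim=-1) + eps), computed in float32
    # and cast back to the input dtype.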

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        return self.weight * hidden_states.to(input_dtype)

    def extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.variance_epsilon}"


class MambaBlock(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mixer = MambaMixer(config, layer_idx=layer_idx)
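
    # Pre-norm residual block: normalize the hidden states, run the mixer, then add the residual
    # (kept in float32 when `config.residual_in_fp32` is set).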

    def forward(
        self,
        hidden_states,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        residual = hidden_states
        hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
        if self.residual_in_fp32:
            residual = residual.to(torch.float32)

        hidden_states = self.mixer(
            hidden_states, cache_params=cache_params, cache_position=cache_position, attention_mask=attention_mask
        )
        hidden_states = residual + hidden_states
        return hidden_states


class MambaPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = MambaConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["MambaBlock", "MambaMixer"]
    supports_gradient_checkpointing = True
    _is_stateful = True

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, MambaMixer):
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale
            if self.config.time_step_init_scheme == "constant":
                nn.init.constant_(module.dt_proj.weight, dt_init_std)
            elif self.config.time_step_init_scheme == "random":
                nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std)

            dt = torch.exp(
                torch.rand(self.config.intermediate_size)
                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
                + math.log(self.config.time_step_min)
            ).clamp(min=self.config.time_step_floor)
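            # `dt_proj.bias` is added before the softplus in the forward pass, so it is set to the
            # inverse softplus of the sampled dt below, making softplus(bias) reproduce dt at init.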
            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                module.dt_proj.bias.copy_(inv_dt)
            module.dt_proj.bias._no_reinit = True

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                if not getattr(module.bias, "_no_reinit", False):
                    nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.initializer_range)

        if self.config.rescale_prenorm_residual:
            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
            #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
            #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
            #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
            #
            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
            for name, p in module.named_parameters():
                if name in ["out_proj.weight"]:
                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                    # We need to reinit p since this code could be called multiple times
                    # Having just p *= scale would repeatedly scale it down
                    nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                    with torch.no_grad():
                        p /= math.sqrt(self.config.num_hidden_layers)


@dataclass
class MambaOutput(ModelOutput):
    """
    Class for the MAMBA model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class MambaCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        cache_params (`MambaCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


MAMBA_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`MambaConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

MAMBA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            Indices of input sequence tokens in the vocabulary.

            If `cache_params.seqlen_offset > 0`, only `input_ids` that do not have their past calculated should be
            passed as `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        cache_params (`MambaCache`, *optional*):
            If passed along, the model uses the previous state in all the blocks (which will give the output for the
            `input_ids` provided as if the model adds `state_input_ids + input_ids` as context).
        use_cache (`bool`, *optional*):
            If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""


@add_start_docstrings(
    "The bare MAMBA Model transformer outputting raw hidden-states without any specific head on top.",
    MAMBA_START_DOCSTRING,
)
class MambaModel(MambaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Initialize weights and apply final processing
        self._register_load_state_dict_pre_hook(self.load_hook)
        self.post_init()

    def load_hook(self, state_dict, prefix, *args):
        for k in state_dict:
            if "embedding." in k:
                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
                break

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MambaOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        cache_params: Optional[MambaCache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, MambaOutput]:
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):  # ^ is python for xor
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if use_cache:
            if cache_params is None:
                cache_params = MambaCache(
                    self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
                )
                cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
            elif cache_position is None:
                # cases when we do a manual forward instead of using `model.generate`, which initializes
                # `cache_position` and makes sure it is not None; throw an error here instead of trying
                # to guess the current cache position with some hack
                raise ValueError(
                    "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
                    "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
                    "be initialized for you automatically"
                )
        else:
            cache_params = None

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        for mixer_block in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    mixer_block.__call__, hidden_states, cache_params, cache_position, attention_mask
                )
            else:
                hidden_states = mixer_block(
                    hidden_states,
                    cache_params=cache_params,
                    cache_position=cache_position,
                    attention_mask=attention_mask,
                )

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None)

        return MambaOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params if use_cache else None,
            hidden_states=all_hidden_states,
        )


@add_start_docstrings(
    """
    The MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    MAMBA_START_DOCSTRING,
)
class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.backbone = MambaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    def _update_model_kwargs_for_generation(
        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1, **kwargs
    ) -> Dict[str, Any]:
        model_kwargs["cache_params"] = outputs.get("cache_params", None)
        if (
            model_kwargs.get("use_cache", True)
            and "cache_position" in model_kwargs
            and model_kwargs["cache_position"] is not None
        ):
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens

        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        input_ids,
        inputs_embeds=None,
        use_cache=None,
        cache_params: Optional[MambaCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        # Overwritten -- uses `cache_params` as opposed to `past_key_values`

        if use_cache:
            # `cache_position` should have been initialized in `generate`
            if cache_position is None:
                raise ValueError(
                    "`cache_position` should not be None as it should have been initialized in "
                    "`model.generate`, you are responsible for passing in a valid `cache_position` if "
                    "you are calling `prepare_inputs_for_generation` directly with `use_cache=True`"
                )
            if cache_position[0] > 0:
                input_ids = input_ids[:, -1].unsqueeze(-1)

                if attention_mask is not None:
                    attention_mask = None
            else:
                # we initialize the `cache_position` to full size of `conv_states` at prefill stage
                # considering padding will be applied when input length is shorter, and truncation
                # will be applied when it is longer, so it will be equivalent to always have it match
                # the length of `cache_params.conv_states`, which is `config.conv_kernel`
                cache_position = torch.arange(0, self.config.conv_kernel, device=input_ids.device)

        if inputs_embeds is not None and cache_params is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}

        model_inputs.update(
            {
                "cache_params": cache_params,
                "use_cache": use_cache,
                "cache_position": cache_position,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=MambaCausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[MambaCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,  # for now we need this for generation
    ) -> Union[Tuple, MambaCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        mamba_outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
            cache_position=cache_position,
            attention_mask=attention_mask,
        )
        hidden_states = mamba_outputs[0]

        logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + mamba_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MambaCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=mamba_outputs.cache_params,
            hidden_states=mamba_outputs.hidden_states,
        )
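

# Minimal usage sketch (not part of the modeling code), assuming the `transformers` auto classes
# and the `state-spaces/mamba-130m-hf` checkpoint referenced in `_CHECKPOINT_FOR_DOC` above:
#
#     from transformers import AutoTokenizer, MambaForCausalLM
#
#     tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
#     model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
#     inputs = tokenizer("Mamba is a selective state space model that", return_tensors="pt")
#     output_ids = model.generate(**inputs, max_new_tokens=20)
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))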