# coding=utf-8
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLOOM model."""

import math
import warnings
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F

from ...cache_utils import Cache, DynamicCache, StaticCache
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from .configuration_bloom import BloomConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "bigscience/bloom-560m"
_CONFIG_FOR_DOC = "BloomConfig"
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    """
    Link to paper: https://arxiv.org/abs/2108.12409. The ALiBi tensor here is not causal as in the original paper; it
    relies on the translation invariance of softmax for a quick implementation: with l a tensor and a a fixed value,
    `softmax(l+a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    TODO @thomasw21: this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

    Args:
        attention_mask (`torch.Tensor`):
            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
        num_heads (`int`):
            Number of attention heads.
        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
            Dtype of the output tensor.

    Returns:
        Tensor of shape (batch_size * num_heads, 1, max_seq_len).
    """
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
    )
    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    # Note: alibi will be added to the attention bias that is applied to the query, key product of attention,
    # so alibi has to be of shape (batch_size, num_heads, query_length, key_length).
    # Here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) so that the
    # query_length dimension is broadcast correctly.
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
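# A minimal shape sketch of the helper above (added for illustration, not part of the original file): the 2D padding
# mask yields one linear position bias per head. The toy mask and head count here are assumptions.
#
#     mask = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])             # (batch_size=2, seq_len=4), left padding in row 0
#     alibi = build_alibi_tensor(mask, num_heads=8, dtype=torch.float32)
#     print(alibi.shape)                                            # torch.Size([16, 1, 4]) == (batch * heads, 1, seq_len)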
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    """
    Dropout add function

    Args:
        x (`torch.tensor`):
            input tensor
        residual (`torch.tensor`):
            residual tensor
        prob (`float`):
            dropout probability
        training (`bool`):
            training mode
    """
    out = F.dropout(x, p=prob, training=training)
    out = residual + out
    return out
def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
    """
    Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference)
    to keep the model jittable.

    Args:
        x (`torch.tensor`):
            input hidden states
    """
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Gradient of the tanh approximation of gelu. The gradient of the exact gelu is:
    0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)

    Args:
        g (`torch.tensor`):
            gradient output tensor
        x (`torch.tensor`):
            input tensor
    """
    x = x[0]  # x is a tuple of 1 element, we need to unpack it first
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return bloom_gelu_forward(input)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        input = ctx.saved_tensors
        tmp = bloom_gelu_back(grad_output, input)
        return tmp
class BloomGelu(nn.Module):
    """
    BloomBiasGelu wrapper that uses the simple functional implementation in inference mode, to keep the model
    torchscriptable, and the autograd function in training mode to get accurate gradients. Partly copied from
    Megatron-DeepSpeed code and adapted for our needs.

    See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            return GeLUFunction.apply(x)
        else:
            return bloom_gelu_forward(x)
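# For reference (illustrative comment, not part of the original module): the tanh formula in `bloom_gelu_forward`
# is the same approximation PyTorch exposes as the "tanh" GELU, so the two should agree closely (up to the
# truncated sqrt(2/pi) constant used above). A quick sanity check, assuming a random input:
#
#     x = torch.randn(4, 8)
#     torch.testing.assert_close(bloom_gelu_forward(x), F.gelu(x, approximate="tanh"), rtol=1e-4, atol=1e-4)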
class BloomAttention(nn.Module):
    def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact
        self.hidden_size = config.hidden_size
        self.num_heads = config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.hidden_dropout = config.hidden_dropout

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
                f" {self.num_heads})."
            )

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = 1.0

        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def _reshape(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Split the last dimension into (num_heads, head_dim) and reshape to (bs, heads, len, dim) without making any
        copies; the results share the same memory storage as `fused_qkv`.

        Args:
            fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]

        Returns:
            query: [batch_size, num_heads, seq_length, head_dim]
            key: [batch_size, num_heads, seq_length, head_dim]
            value: [batch_size, num_heads, seq_length, head_dim]
        """
        batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
        fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
        query_layer = fused_qkv[..., 0, :].transpose(1, 2)
        key_layer = fused_qkv[..., 1, :].transpose(1, 2)
        value_layer = fused_qkv[..., 2, :].transpose(1, 2)
        return query_layer, key_layer, value_layer

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Merge heads together over the last dimension

        Args:
            x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim]

        Returns:
            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
        """
        # What we want to achieve is:
        # batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
        batch_size_and_num_heads, seq_length, _ = x.shape
        batch_size = batch_size_and_num_heads // self.num_heads

        # First view to decompose the batch size
        # batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
        x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)

        # batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
        x = x.permute(0, 2, 1, 3)

        # batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
        return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Cache] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        batch_size, q_length, _ = hidden_states.shape
        fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
        # 3 x [batch_size, num_heads, seq_length, head_dim]
        query_layer, key_layer, value_layer = self._reshape(fused_qkv)

        if layer_past is not None:
            cache_kwargs = {"cache_position": cache_position}
            key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)

        # reshape qkv for further computations
        query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
        key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
        value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)

        # [batch_size * num_heads, q_length, kv_length]
        attention_scores = alibi.baddbmm(
            batch1=query_layer,
            batch2=key_layer,
            beta=self.beta,
            alpha=self.inv_norm_factor,
        )

        # change view to [batch_size, num_heads, q_length, kv_length]
        attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
            attn_weights = attn_weights + causal_mask

        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)

        # [batch_size, num_heads, q_length, kv_length]
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # change view [batch_size x num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = torch.bmm(attention_probs_reshaped, value_layer)

        # change view [batch_size, q_length, num_heads * head_dim]
        context_layer = self._merge_heads(context_layer)

        # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
        if self.pretraining_tp > 1 and self.slow_but_exact:
            slices = self.hidden_size / self.pretraining_tp
            output_tensor = torch.zeros_like(context_layer)
            for i in range(self.pretraining_tp):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            output_tensor = self.dense(context_layer)

        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)

        outputs = (output_tensor, layer_past)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs
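# Illustrative note (not in the original file): the `alibi.baddbmm(...)` call above computes, per (batch, head) slice,
#
#     attention_scores = self.beta * alibi + self.inv_norm_factor * (query_layer @ key_layer)
#
# i.e. the ALiBi bias is fused into the scaled Q.K^T product in a single batched matmul, with `key_layer` already
# transposed to [batch_size * num_heads, head_dim, kv_length].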
class BloomMLP(nn.Module):
    def __init__(self, config: BloomConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact
        self.dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
        self.gelu_impl = BloomGelu()
        self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
        self.hidden_dropout = config.hidden_dropout

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))

        if self.pretraining_tp > 1 and self.slow_but_exact:
            intermediate_output = torch.zeros_like(residual)
            slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
            for i in range(self.pretraining_tp):
                intermediate_output = intermediate_output + F.linear(
                    hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
                    self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
                )
        else:
            intermediate_output = self.dense_4h_to_h(hidden_states)

        output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
        return output
class BloomBlock(nn.Module):
    def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.num_heads = config.n_head
        self.self_attention = BloomAttention(config, layer_idx)
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config)
        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        self.hidden_dropout = config.hidden_dropout

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Cache] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        # hidden_states: [batch_size, seq_length, hidden_size]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Layer norm post the self attention.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attn_outputs = self.self_attention(
            layernorm_output,
            residual,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )

        attention_output = attn_outputs[0]
        outputs = attn_outputs[1:]

        layernorm_output = self.post_attention_layernorm(attention_output)

        # Get residual
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output, residual)

        if use_cache:
            outputs = (output,) + outputs
        else:
            outputs = (output,) + outputs[1:]

        return outputs  # hidden_states, past_kv, attentions
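# Dataflow sketch for one block (comment added for clarity, not in the original file), assuming
# `apply_residual_connection_post_layernorm=False` (the config default):
#
#     ln1 = input_layernorm(h)
#     h   = h + dropout(attn(ln1))          # attention residual, added inside BloomAttention via dropout_add
#     ln2 = post_attention_layernorm(h)
#     h   = h + dropout(mlp(ln2))           # MLP residual, added inside BloomMLP via dropout_add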
class BloomPreTrainedModel(PreTrainedModel):
    config_class = BloomConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BloomBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_static_cache = True
    _supports_quantized_cache = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
BLOOM_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

BLOOM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance, see our
              [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
              cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
@add_start_docstrings(
    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
    BLOOM_START_DOCSTRING,
)
class BloomModel(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        self.num_heads = config.n_head

        # Embedding + LN Embedding
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([BloomBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
        return build_alibi_tensor(attention_mask, num_heads, dtype)

    def get_input_embeddings(self):
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        self.word_embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        batch_size, seq_length, _ = inputs_embeds.shape
        past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        seq_length_with_past = seq_length + past_length
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape batch_size x num_heads x N x N
        # head_mask has shape n_layer x batch x num_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        next_decoder_cache = None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Compute alibi tensor: check build_alibi_tensor documentation
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        for i, block in enumerate(self.h):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    alibi,
                    causal_mask,
                    past_key_values,
                    head_mask[i],
                    use_cache,
                    output_attentions,
                    cache_position,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=past_key_values,
                    attention_mask=causal_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=alibi,
                    cache_position=cache_position,
                )

            hidden_states = outputs[0]
            if use_cache:
                next_decoder_cache = outputs[1]
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(
                v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
    @staticmethod
    # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
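# Minimal illustration (added comment, not part of the original file) of the 4D mask produced above for a tiny
# prefill step with no padding; the toy sizes are assumptions:
#
#     mask = BloomModel._prepare_4d_causal_attention_mask_with_cache_position(
#         attention_mask=torch.ones(1, 3),
#         sequence_length=3,
#         target_length=3,
#         dtype=torch.float32,
#         device="cpu",
#         cache_position=torch.arange(3),
#         batch_size=1,
#     )
#     # mask[0, 0] is upper-triangular: `torch.finfo(torch.float32).min` above the diagonal, 0 on and below it.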
@add_start_docstrings(
    """
    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: BloomConfig):
        super().__init__(config)
        self.transformer = BloomModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: torch.Tensor):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        # Overwritten because of the fixed-shape attention mask creation

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s
            # `mode="reduce-overhead"`, as otherwise the input `input_ids` would have varying strides during decoding.
            # Here, simply using `.contiguous()` is not sufficient, as in the batch size = 1 case the tensor is
            # already contiguous but with varying stride, which retriggers a capture.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        # This part differs from other models because BLOOM needs a 2D mask to construct the alibi tensor
        # The only difference is the usage of a 2D instead of a 4D mask, but the shape will be static
        if isinstance(past_key_values, StaticCache) and attention_mask is not None:
            target_length = past_key_values.get_max_length()
            batch_size, seq_length = attention_mask.shape
            diff = target_length - seq_length

            new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
            attention_mask = torch.cat(
                [attention_mask, new_attn_mask],
                dim=-1,
            )

        model_inputs.update(
            {
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs
    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            batch_size, seq_length, vocab_size = shift_logits.shape
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def _reorder_cache(
        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        # Get a copy of `beam_idx` on all the devices where we need those indices.
        device_to_beam_idx = {
            past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
        }
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
            )
            for layer_past in past
        )
        return reordered_past
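# Usage sketch for the causal LM head (illustrative comment, not part of the original file); the checkpoint name
# matches the `_CHECKPOINT_FOR_DOC` constant above.
#
#     from transformers import AutoTokenizer, BloomForCausalLM
#
#     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
#     model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")
#     inputs = tokenizer("Hello, my name is", return_tensors="pt")
#     output_ids = model.generate(**inputs, max_new_tokens=20)
#     print(tokenizer.decode(output_ids[0], skip_special_tokens=True))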
@add_start_docstrings(
    """
    The Bloom Model transformer with a sequence classification head on top (linear layer).

    [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForSequenceClassification(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = BloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning_once(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
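# The pooling logic above in a toy example (illustrative comment, not part of the original file), with a hypothetical
# right-padded batch and `pad` standing for `config.pad_token_id`:
#
#     input_ids        = [[5, 7, 9, pad, pad], [3, 4, 6, 8, 2]]
#     sequence_lengths = [2, 4]                 # index of the last non-padding token per row
#     pooled_logits    = logits[[0, 1], [2, 4]] # one logit vector per sequence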
@add_start_docstrings(
    """
    Bloom Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForTokenClassification(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = BloomModel(config)
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        if deprecated_arguments.pop("position_ids", False) is not False:
            # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            batch_size, seq_length = labels.shape
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
            )

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
@add_start_docstrings(
    """
    The BLOOM Model transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForQuestionAnswering(BloomPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = BloomModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, splitting adds a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
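# Usage sketch for the QA head (illustrative comment, not part of the original file). Note that `position_ids` is
# accepted for API compatibility but BLOOM does not use it; the question/context strings are made up.
#
#     from transformers import AutoTokenizer, BloomForQuestionAnswering
#
#     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
#     model = BloomForQuestionAnswering.from_pretrained("bigscience/bloom-560m")
#     inputs = tokenizer("Who trained BLOOM?", "The BigScience workshop trained BLOOM.", return_tensors="pt")
#     outputs = model(**inputs)
#     start = outputs.start_logits.argmax(-1)
#     end = outputs.end_logits.argmax(-1)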