# coding=utf-8
# Copyright 2023 HuggingFace Inc. Team and Bigscience Workshop. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax BLOOM model."""
import math
from functools import partial
from typing import Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask
from flax.linen.activation import tanh
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (
    FlaxBaseModelOutput,
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxCausalLMOutput,
)
from ...modeling_flax_utils import FlaxPreTrainedModel, append_call_sample_docstring
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_bloom import BloomConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "bigscience/bloom"
_CONFIG_FOR_DOC = "BloomConfig"

BLOOM_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads, etc.).

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`BloomConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified, all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

BLOOM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
            Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
            auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

def build_alibi_tensor(attention_mask: jnp.ndarray, num_heads: int, dtype: Optional[jnp.dtype] = jnp.float32):
    """
    Flax implementation of the BLOOM Alibi tensor. The BLOOM Alibi tensor is not causal, as the original paper
    mentions; it relies on the translation invariance of softmax for a quick implementation: for a tensor `l` and a
    fixed value `a`, `softmax(l + a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    Link to paper: https://arxiv.org/abs/2108.12409

    Args:
        attention_mask (`jnp.ndarray`):
            Token-wise attention mask, this should be of shape `(batch_size, max_seq_len)`.
        num_heads (`int`):
            Number of attention heads.
        dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The data type (dtype) of the output tensor.

    Returns: Alibi tensor of shape `(batch_size, num_heads, 1, max_seq_len)`.
    """
    batch_size, seq_length = attention_mask.shape
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = jnp.array(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=jnp.float32)
    powers = jnp.arange(1, 1 + closest_power_of_2, dtype=jnp.float32)
    slopes = jax.lax.pow(base, powers)

    if closest_power_of_2 != num_heads:
        extra_base = jnp.array(2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=jnp.float32)
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = jnp.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=jnp.float32)
        slopes = jnp.concatenate([slopes, jax.lax.pow(extra_base, extra_powers)], axis=0)

    # Note: the Alibi tensor is added to the attention bias that is applied to the query-key product of attention;
    # therefore, Alibi has to be of shape (batch_size, num_heads, query_length, key_length).
    # Here we build it with query_length=1 so that the query_length dimension is broadcast correctly.
    # This is more or less identical to T5's relative position bias:
    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
    arange_tensor = ((attention_mask.cumsum(axis=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor
    alibi = jnp.expand_dims(alibi, axis=2)
    return jnp.asarray(alibi, dtype)
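
# Usage sketch for `build_alibi_tensor` (added comment, illustrative only; not executed at import time):
#
#     mask = jnp.ones((2, 5), dtype="i4")          # batch of 2 fully-unmasked sequences of length 5
#     alibi = build_alibi_tensor(mask, num_heads=16)
#     assert alibi.shape == (2, 16, 1, 5)          # broadcasts over the query axis inside attention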

class FlaxBloomAttention(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.hidden_size = self.config.hidden_size
        self.num_heads = self.config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_softmax_in_fp32 = self.dtype is not jnp.float32

        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and "
                f"`num_heads`: {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.query_key_value = dense(self.hidden_size * 3)
        self.dense = dense(self.hidden_size)
        self.resid_dropout = nn.Dropout(rate=self.config.hidden_dropout)

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:-1] + (self.num_heads, self.head_dim * 3))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))

    @nn.compact
    # Copied from transformers.models.gptj.modeling_flax_gptj.FlaxGPTJAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to
        cached states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those
            # key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask
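
    # Illustration (added descriptive comment, not executed): with max_length=8 and cache_index=3, decoding one
    # new token writes its key/value slice at position 3, the pad mask limits attention to positions 0..3, and
    # cache_index advances to 4 for the next step.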

    def __call__(
        self,
        hidden_states,
        residual,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        # proj q, k, v
        fused_qkv = self.query_key_value(hidden_states)
        fused_qkv = self._split_heads(fused_qkv)
        query, key, value = jnp.split(fused_qkv, 3, axis=-1)

        causal_attention_mask = make_causal_mask(attention_mask, dtype="bool")

        # for fast decoding causal attention mask should be shifted
        causal_attention_mask_shift = (
            self.variables["cache"]["cache_index"] if self.has_variable("cache", "cached_key") else 0
        )

        # fast decoding for generate requires special attention_mask
        if self.has_variable("cache", "cached_key"):
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_attention_mask = jax.lax.dynamic_slice(
                causal_attention_mask,
                (0, 0, causal_attention_mask_shift, 0),
                (1, 1, seq_length, max_decoder_length),
            )

        # broadcast causal attention mask & attention mask to fit for merge
        causal_attention_mask = jnp.broadcast_to(
            causal_attention_mask, (batch_size,) + causal_attention_mask.shape[1:]
        )
        attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_attention_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_attention_mask)

        dropout_rng = None
        if not deterministic and self.config.attention_dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)

        # transform boolean mask into float mask
        mask_value = jnp.finfo(self.dtype).min
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, mask_value).astype(self.dtype),
        )

        attention_bias = attention_bias + alibi

        # Cast in fp32 if the original dtype is different from fp32
        attention_dtype = jnp.float32 if self.attention_softmax_in_fp32 else self.dtype

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_dropout,
            deterministic=deterministic,
            dtype=attention_dtype,
        )

        # Cast back in the original dtype if the native dtype is not fp32
        if self.attention_softmax_in_fp32:
            attn_weights = attn_weights.astype(self.dtype)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.dense(attn_output)
        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)

        attn_output = attn_output + residual

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
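
# Note (added descriptive comment): BloomGELU below implements the tanh approximation of GELU,
#     gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))),
# where sqrt(2 / pi) ~= 0.79788456 is the constant that appears in its __call__.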

class BloomGELU(nn.Module):
    def setup(self):
        self.dtype = jnp.float32

    def __call__(self, x):
        return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

class FlaxBloomMLP(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        hidden_size = self.config.hidden_size

        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)

        self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init)
        self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init)
        self.hidden_dropout = nn.Dropout(self.config.hidden_dropout)
        self.act = BloomGELU()

    def __call__(self, hidden_states, residual, deterministic: bool = True):
        hidden_states = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)

        intermediate_output = self.dense_4h_to_h(hidden_states)

        intermediate_output = intermediate_output + residual
        hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic)

        return hidden_states

class FlaxBloomBlock(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

        self.self_attention = FlaxBloomAttention(self.config, dtype=self.dtype)
        self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

        self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype)

        self.apply_residual_connection_post_layernorm = self.config.apply_residual_connection_post_layernorm
        self.hidden_dropout = self.config.hidden_dropout

    def __call__(
        self,
        hidden_states,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        layernorm_output = self.input_layernorm(hidden_states)

        # layer norm before saving residual if config calls for it
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # self-attention
        attn_outputs = self.self_attention(
            layernorm_output,
            residual=residual,
            alibi=alibi,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
        )

        attention_output = attn_outputs[0]
        outputs = attn_outputs[1:]

        post_layernorm = self.post_attention_layernorm(attention_output)

        # set residual based on config
        if self.apply_residual_connection_post_layernorm:
            residual = post_layernorm
        else:
            residual = attention_output

        output = self.mlp(post_layernorm, residual, deterministic=deterministic)

        outputs = (output,) + outputs

        return outputs

class FlaxBloomPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BloomConfig
    base_model_prefix = "transformer"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BloomConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(rngs, input_ids, attention_mask, return_dict=False)["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])
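
    # Usage sketch (added comment, illustrative): the cache returned by `init_cache` is what
    # `prepare_inputs_for_generation` feeds back in as `past_key_values`, e.g.
    #
    #     model = FlaxBloomForCausalLM.from_pretrained("bigscience/bloom-560m")  # any BLOOM checkpoint
    #     cache = model.init_cache(batch_size=1, max_length=20)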

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        past_key_values: dict = None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, sequence_length = input_ids.shape

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized and the private flag `init_cache` has to
        # be passed down to make sure the cache is used. The cache must also be marked as mutable so that it can be
        # updated by the FlaxBloomAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs

class FlaxBloomBlockCollection(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = [
            FlaxBloomBlock(self.config, name=str(layer_number), dtype=self.dtype)
            for layer_number in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        alibi,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for layer_number in range(self.config.num_hidden_layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = self.layers[layer_number](
                hidden_states,
                alibi=alibi,
                attention_mask=attention_mask,
                deterministic=deterministic,
                init_cache=init_cache,
                output_attentions=output_attentions,
            )
            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        # this contains possible `None` values - `FlaxBloomModule` will filter them out
        outputs = (hidden_states, all_hidden_states, all_attentions)

        return outputs

class FlaxBloomModule(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.embed_dim = self.config.hidden_size

        # word embeddings (no positional embedding layer)
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.embed_dim,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )

        # post-embedding layernorm
        self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

        # transformer layers
        self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype)

        # final layernorm
        self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

    def __call__(
        self,
        input_ids=None,
        attention_mask=None,
        deterministic=True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        inputs_embeds = self.word_embeddings(input_ids)
        # do post-embedding layernorm
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        # build alibi depending on `attention_mask`
        alibi = build_alibi_tensor(attention_mask, self.config.n_head, dtype=hidden_states.dtype)

        outputs = self.h(
            hidden_states,
            alibi=alibi,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )

        hidden_states = outputs[0]
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = outputs[1] + (hidden_states,)
            outputs = (hidden_states, all_hidden_states) + outputs[2:]
        else:
            outputs = (hidden_states,) + outputs[1:]

        if not return_dict:
            return tuple(v for v in [outputs[0], outputs[-1]] if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=outputs[1],
            attentions=outputs[-1],
        )

@add_start_docstrings(
    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
    BLOOM_START_DOCSTRING,
)
# Copied from transformers.models.gpt_neo.modeling_flax_gpt_neo.FlaxGPTNeoModel with GPTNeo->Bloom
class FlaxBloomModel(FlaxBloomPreTrainedModel):
    module_class = FlaxBloomModule


append_call_sample_docstring(FlaxBloomModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)

class FlaxBloomForCausalLMModule(nn.Module):
    config: BloomConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.transformer = FlaxBloomModule(self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            shared_kernel = self.transformer.variables["params"]["word_embeddings"]["embedding"].T
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

@add_start_docstrings(
    """
    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    BLOOM_START_DOCSTRING,
)
class FlaxBloomForCausalLM(FlaxBloomPreTrainedModel):
    module_class = FlaxBloomForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for
        # x > input_ids.shape[-1] and x < cache_length. But since Bloom uses a causal mask,
        # those positions are masked anyway. Thus, we can create a single static attention_mask here,
        # which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        return model_kwargs
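
# Generation usage sketch (added comment, illustrative; assumes a tokenizer and a BLOOM checkpoint are available):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
#     model = FlaxBloomForCausalLM.from_pretrained("bigscience/bloom-560m")
#     inputs = tokenizer("Hello, my dog is", return_tensors="np")
#     output_ids = model.generate(**inputs, max_length=20).sequences
#     print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))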

append_call_sample_docstring(FlaxBloomForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutput, _CONFIG_FOR_DOC)