| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799 |
- # coding=utf-8
- # Copyright 2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Flax OPT model."""
- from functools import partial
- from typing import Optional, Tuple
- import flax.linen as nn
- import jax
- import jax.numpy as jnp
- from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
- from flax.linen import combine_masks, make_causal_mask
- from flax.linen.attention import dot_product_attention_weights
- from flax.traverse_util import flatten_dict, unflatten_dict
- from jax import lax
- from jax.random import PRNGKey
- from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxMaskedLMOutput
- from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
- from ...utils import add_start_docstrings, logging
- from .configuration_opt import OPTConfig
logger = logging.get_logger(__name__)

# Checkpoint / config names interpolated into the usage examples generated by the
# `append_call_sample_docstring(...)` calls further down in this module.
_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
_CONFIG_FOR_DOC = "OPTConfig"


# Shared class-level documentation prepended to the public OPT model classes via `add_start_docstrings`.
OPT_START_DOCSTRING = r"""
    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`OPTConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

# Shared argument documentation for the model `__call__`.
# NOTE(review): not referenced anywhere else in this module as shown; kept in case other modules import it.
OPT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention with Bart->OPT
class FlaxOPTAttention(nn.Module):
    """
    Multi-head dot-product attention with optional causal masking and an auto-regressive key/value cache.

    Keys and values are projected from `hidden_states` (self-attention) unless `key_value_states` is passed to
    `__call__`, in which case they are projected from it (cross-attention). `FlaxOPTDecoderLayer` only uses the
    self-attention path, with `causal=True`.
    """

    config: OPTConfig  # read for `init_std` (projection kernel init) and `max_position_embeddings` (causal mask size)
    embed_dim: int  # number of output features of each q/k/v/out projection
    num_heads: int  # number of attention heads; must divide `embed_dim` exactly (checked in `setup`)
    dropout: float = 0.0  # dropout rate passed to `dot_product_attention_weights`
    causal: bool = False  # if True, a causal mask is combined with `attention_mask`
    bias: bool = True  # whether the q/k/v/out projections use a bias term
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        # The four projections share one Dense configuration (features, bias, dtype, kernel init).
        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        # NOTE(review): `dropout_layer` is not called in `__call__`; attention dropout is applied through
        # `dot_product_attention_weights(dropout_rate=...)` instead.
        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.causal:
            # Boolean causal mask sized for the longest supported sequence. `__call__` indexes it with four axes and
            # slices it down to the current query/key lengths, or to the cache window during cached decoding.
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        # Keep the first two axes and split the last one into (num_heads, head_dim).
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        # Inverse of `_split_heads`: collapse (num_heads, head_dim) back into a single embed_dim axis.
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        Write the current step's projected key/value states into the `cache` collection.

        Returns the full cached key/value tensors together with an attention mask restricted to the cache positions
        filled so far. On the very first call the `cache` variables do not exist yet, so they are created
        zero-filled with the shapes of `key`/`value` and the inputs are returned unchanged. This happens, for
        example, during the `init_cache` pass.

        This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        # This check must run BEFORE the `self.variable(...)` calls below, because they create the entries.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices, written at the current cache index
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            # advance the write position by however many query positions were processed in this step
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: the query positions may only attend to key positions
            # that have already been generated and cached, not to the remaining zero-filled slots.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Input shape: Batch x Time x Channel

        Args:
            hidden_states: queries source; also the keys/values source when `key_value_states` is None.
            key_value_states: optional keys/values source (cross-attention).
            attention_mask: optional 0/1 mask; it is expanded on axes -3 and -2 before being combined.
            init_cache: when `causal`, create the `cache` collection on this call.
            deterministic: disables attention dropout when True.

        Returns:
            `(attn_output, attn_weights)`: the projected attention output and the attention probabilities.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # Cached decoding: take `query_length` rows starting at the current cache index, spanning the whole
                # cache width.
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias: 0 where attended, the dtype's minimum value elsewhere
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        # A dropout RNG is only requested when dropout can actually fire.
        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Weighted sum of values per head, then merge heads and apply the output projection.
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
class FlaxOPTDecoderLayer(nn.Module):
    """
    One OPT decoder block: causal self-attention followed by a two-layer feed-forward network.

    Each sub-layer sits inside a residual connection. Its layer norm is applied to the sub-layer input when
    `config.do_layer_norm_before` is true (pre-LN), or to the residual sum otherwise (post-LN). The original comments
    in this file note that 125m, 1.7B, ..., 175B use pre-LN and 350m uses post-LN.
    """

    config: OPTConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        cfg = self.config
        weight_init = jax.nn.initializers.normal(cfg.init_std)

        self.embed_dim = cfg.hidden_size
        self.self_attn = FlaxOPTAttention(
            config=cfg,
            embed_dim=self.embed_dim,
            num_heads=cfg.num_attention_heads,
            dropout=cfg.attention_dropout,
            causal=True,
            dtype=self.dtype,
        )
        self.do_layer_norm_before = cfg.do_layer_norm_before
        # One dropout module, applied after both the attention and the feed-forward sub-layers.
        self.dropout_layer = nn.Dropout(rate=cfg.dropout)
        self.activation_fn = ACT2FN[cfg.activation_function]

        self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        self.fc1 = nn.Dense(cfg.ffn_dim, dtype=self.dtype, kernel_init=weight_init)
        self.fc2 = nn.Dense(self.embed_dim, dtype=self.dtype, kernel_init=weight_init)
        self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Apply the block to `hidden_states`.

        Returns:
            `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when `output_attentions` is true.
        """
        pre_ln = self.do_layer_norm_before

        # ---- self-attention sub-layer ----
        skip = hidden_states
        attn_in = self.self_attn_layer_norm(hidden_states) if pre_ln else hidden_states
        attn_out, attn_probs = self.self_attn(
            hidden_states=attn_in,
            attention_mask=attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
        )
        hidden_states = skip + self.dropout_layer(attn_out, deterministic=deterministic)
        if not pre_ln:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # ---- feed-forward sub-layer, computed on a flattened (tokens, features) view ----
        full_shape = hidden_states.shape
        flat = hidden_states.reshape(-1, hidden_states.shape[-1])
        ffn_in = self.final_layer_norm(flat) if pre_ln else flat
        ffn_out = self.fc2(self.activation_fn(self.fc1(ffn_in)))
        ffn_out = self.dropout_layer(ffn_out, deterministic=deterministic)
        hidden_states = (flat + ffn_out).reshape(full_shape)
        if not pre_ln:
            hidden_states = self.final_layer_norm(hidden_states)

        if output_attentions:
            return (hidden_states, attn_probs)
        return (hidden_states,)
class FlaxOPTDecoderLayerCollection(nn.Module):
    """Stack of `config.num_hidden_layers` `FlaxOPTDecoderLayer`s applied one after another."""

    config: OPTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Explicit names "0", "1", ... determine the keys of each layer in the parameter tree.
        self.layers = [
            FlaxOPTDecoderLayer(self.config, name=str(layer_idx), dtype=self.dtype)
            for layer_idx in range(self.config.num_hidden_layers)
        ]
        # NOTE(review): stored but never consulted in `__call__`; layer dropping is not implemented here.
        self.layerdrop = self.config.layerdrop

    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
    ):
        """
        Run every decoder layer in order.

        Returns:
            A list `[hidden_states, all_hidden_states, all_self_attns]`. `all_hidden_states` is a tuple holding the
            input of each layer, or None when not requested. `all_self_attns` is a tuple holding each layer's
            attention weights, or None when not requested.
        """
        layer_inputs = [] if output_hidden_states else None
        layer_attns = [] if output_attentions else None

        for block in self.layers:
            if layer_inputs is not None:
                layer_inputs.append(hidden_states)

            block_out = block(
                hidden_states,
                attention_mask=attention_mask,
                init_cache=init_cache,
                output_attentions=output_attentions,
                deterministic=deterministic,
            )
            hidden_states = block_out[0]

            if layer_attns is not None:
                layer_attns.append(block_out[1])

        all_hidden_states = None if layer_inputs is None else tuple(layer_inputs)
        all_self_attns = None if layer_attns is None else tuple(layer_attns)
        return [hidden_states, all_hidden_states, all_self_attns]
class FlaxOPTLearnedPositionalEmbedding(nn.Embed):
    """
    Learned absolute position embeddings with a fixed index offset of 2.

    The table holds `num_embeddings + 2` rows, and every position id is shifted by 2 before the lookup. A position
    of -1 therefore maps to row 1. The offset presumably mirrors the checkpoint layout of the reference OPT
    implementation; confirm before changing it.
    """

    def setup(self):
        self.offset = 2
        table_shape = (self.num_embeddings + self.offset, self.features)
        self.embedding = self.param("embedding", self.embedding_init, table_shape, self.param_dtype)

    def __call__(self, positions):
        """Look up embeddings for integer `positions` (any shape) after shifting them by `self.offset`."""
        shifted_positions = positions + self.offset
        return super().__call__(shifted_positions)
class FlaxOPTDecoder(nn.Module):
    """
    The OPT decoder stack.

    Token embeddings (of width `word_embed_proj_dim`, optionally projected to `hidden_size`) plus learned position
    embeddings feed the decoder layers. An optional final layer norm follows, then an optional projection back to
    `word_embed_proj_dim`.
    """

    config: OPTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    offset: int = 2  # NOTE(review): not read anywhere in this class; the positional embedding hard-codes its own offset

    def setup(self):
        # NOTE(review): `dropout_layer`, `padding_idx` and `max_target_positions` are set but not used by `__call__`.
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.hidden_size
        self.padding_idx = self.config.pad_token_id
        self.max_target_positions = self.config.max_position_embeddings

        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            self.config.word_embed_proj_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
            dtype=self.dtype,
        )

        self.embed_positions = FlaxOPTLearnedPositionalEmbedding(
            self.config.max_position_embeddings,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
            dtype=self.dtype,
        )

        # When the token-embedding width differs from the model width, project in before the layers and back out
        # afterwards. Like every other layer of the model, these projections compute in `self.dtype`, as promised
        # by the `dtype` documentation.
        if self.config.word_embed_proj_dim != self.config.hidden_size:
            self.project_in = nn.Dense(self.config.hidden_size, use_bias=False, dtype=self.dtype)
            self.project_out = nn.Dense(self.config.word_embed_proj_dim, use_bias=False, dtype=self.dtype)
        else:
            self.project_in = None
            self.project_out = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if self.config.do_layer_norm_before and not self.config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
        else:
            self.final_layer_norm = None

        self.layers = FlaxOPTDecoderLayerCollection(self.config, self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Embed `input_ids` (flattened to 2-D over the last axis) and run them through the decoder layers.

        Returns:
            A `FlaxBaseModelOutput` when `return_dict` is true. Otherwise a tuple of the non-None entries of
            `(last_hidden_state, all_hidden_states, attentions)`. When requested, `all_hidden_states` ends with the
            final hidden state after the optional final layer norm and `project_out`.
        """
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        inputs_embeds = self.embed_tokens(input_ids)
        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        positions = self.embed_positions(position_ids)

        hidden_states = inputs_embeds + positions

        hidden_state, all_hidden_states, attentions = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        if self.final_layer_norm is not None:
            hidden_state = self.final_layer_norm(hidden_state)

        if self.project_out is not None:
            hidden_state = self.project_out(hidden_state)

        if output_hidden_states:
            all_hidden_states += (hidden_state,)

        outputs = [hidden_state, all_hidden_states, attentions]

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_state,
            hidden_states=all_hidden_states,
            attentions=attentions,
        )
class FlaxOPTPreTrainedModel(FlaxPreTrainedModel):
    """
    Common wrapper for the OPT Flax models.

    Builds `module_class`, initializes its parameters and decoding cache, and runs `module.apply`. The `cache`
    collection is made mutable whenever a cache is supplied.
    """

    config_class = OPTConfig
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: OPTConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """
        Initialize parameters by running `module.init` on dummy inputs of `input_shape` (expected 2-D).

        If `params` is given, only the keys listed in `self._missing_keys` are filled from the fresh initialization.
        That set is then cleared.
        """
        dummy_ids = jnp.zeros(input_shape, dtype="i4")
        dummy_mask = jnp.ones_like(dummy_ids)
        n_rows, n_cols = dummy_ids.shape
        dummy_positions = jnp.broadcast_to(jnp.arange(n_cols)[None, :], (n_rows, n_cols))

        params_key, dropout_key = jax.random.split(rng)
        fresh_params = self.module.init(
            {"params": params_key, "dropout": dropout_key},
            dummy_ids,
            dummy_mask,
            dummy_positions,
            return_dict=False,
        )["params"]

        if params is None:
            return fresh_params

        fresh_flat = flatten_dict(unfreeze(fresh_params))
        merged_flat = flatten_dict(unfreeze(params))
        for missing in self._missing_keys:
            merged_flat[missing] = fresh_flat[missing]
        self._missing_keys = set()
        return freeze(unflatten_dict(merged_flat))

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # A throw-away init pass with `init_cache=True` makes the attention layers create their `cache` variables.
        dummy_ids = jnp.ones((batch_size, max_length), dtype="i4")
        dummy_mask = jnp.ones_like(dummy_ids, dtype="i4")
        dummy_positions = jnp.broadcast_to(jnp.arange(dummy_ids.shape[-1]), dummy_ids.shape)

        variables = self.module.init(
            jax.random.PRNGKey(0), dummy_ids, dummy_mask, dummy_positions, return_dict=False, init_cache=True
        )
        return unfreeze(variables["cache"])

    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        params: dict = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        dropout_rng: PRNGKey = None,
        deterministic: bool = True,
    ):
        """
        Run the wrapped module.

        When `past_key_values` is given, the updated cache is returned as well. With `return_dict` it is stored
        under `"past_key_values"`; otherwise it is inserted as the second tuple element.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.return_dict

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            # Count attended tokens only; masked slots get position -1.
            position_ids = (attention_mask.cumsum(axis=1) * attention_mask) - 1

        variables = {"params": params or self.params}
        rngs = {} if dropout_rng is None else {"dropout": dropout_rng}

        # A supplied cache has to be marked mutable so the attention layers (FlaxOPTAttention) can write into it.
        if past_key_values:
            variables["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            variables,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
            rngs=rngs,
            mutable=mutable,
        )

        if past_key_values is None:
            return outputs

        # With a mutable collection, `apply` yields `(outputs, updated_variables)`; surface the refreshed cache.
        outputs, updated_variables = outputs
        refreshed_cache = unfreeze(updated_variables["cache"])
        if return_dict:
            outputs["past_key_values"] = refreshed_cache
            return outputs
        return outputs[:1] + (refreshed_cache,) + outputs[1:]
class FlaxOPTModule(nn.Module):
    """
    Thin Flax module around `FlaxOPTDecoder`.

    Returns the decoder's tuple unchanged when `return_dict` is false, and repacks it into a `FlaxBaseModelOutput`
    otherwise.
    """

    config: OPTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.decoder = FlaxOPTDecoder(self.config, dtype=self.dtype)

    def _get_decoder_module(self):
        return self.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
        init_cache=False,
    ):
        decoder_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
            init_cache=init_cache,
        )
        decoded = self.decoder(**decoder_kwargs)

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=decoded.last_hidden_state,
                hidden_states=decoded.hidden_states,
                attentions=decoded.attentions,
            )
        return decoded
@add_start_docstrings(
    "The bare OPT Model transformer outputting raw hidden-states without any specific head on top.",
    OPT_START_DOCSTRING,
)
# Copied from transformers.models.bart.modeling_flax_bart.FlaxBartModel with Bart->OPT
class FlaxOPTModel(FlaxOPTPreTrainedModel):
    config: OPTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    module_class = FlaxOPTModule


# Append a usage example (and the `FlaxBaseModelOutput` return description) to `FlaxOPTModel.__call__`.
append_call_sample_docstring(FlaxOPTModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
class FlaxOPTForCausalLMModule(nn.Module):
    """
    `FlaxOPTModule` followed by a vocabulary-sized linear projection (`lm_head`) that produces language-modeling logits.

    When `config.tie_word_embeddings` is true, the decoder's token-embedding table (transposed) is used as the
    projection kernel instead of `lm_head`'s own parameters.
    """

    config: OPTConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.model = FlaxOPTModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """
        Run the decoder and project its last hidden state onto the vocabulary.

        Returns:
            A `FlaxMaskedLMOutput` (`logits`, `hidden_states`, `attentions`) when `return_dict` is true. Otherwise
            the tuple `(lm_logits,) + model_outputs[1:]`.
        """
        outputs = self.model(
            input_ids,
            attention_mask,
            position_ids,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            # The token table is created as `nn.Embed(vocab_size, word_embed_proj_dim)`, so its transpose serves as a
            # (word_embed_proj_dim -> vocab_size) kernel for `lm_head`.
            shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=lm_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    OPT Model with a language modeling head on top (linear layer with weights tied to the input embeddings) e.g. for
    autoregressive tasks.
    """,
    OPT_START_DOCSTRING,
)
class FlaxOPTForCausalLM(FlaxOPTPreTrainedModel):
    module_class = FlaxOPTForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        """
        Build the initial model kwargs for cached auto-regressive generation.

        Args:
            input_ids: prompt token ids of shape `(batch_size, seq_length)`.
            max_length: total length the cache and the extended attention mask are allocated for.
            attention_mask: optional prompt mask. When given, it is copied into the leading columns of the extended
                mask and used to derive `position_ids`.

        Returns:
            dict with `past_key_values` (a fresh cache from `init_cache`), an int32 `attention_mask` of shape
            `(batch_size, max_length)`, and int32 `position_ids` for the prompt.
        """
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and
        # x < cache_length. But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")

        if attention_mask is not None:
            # `lax.dynamic_update_slice` requires the update to have the operand's dtype (int32 here). A bool mask,
            # or an int64 one with x64 enabled, would otherwise raise, so normalize it first.
            attention_mask = jnp.asarray(attention_mask, dtype="i4")
            position_ids = attention_mask.cumsum(axis=1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance to the position after the last one, for the next step."""
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
# `FlaxOPTForCausalLMModule` returns a `FlaxMaskedLMOutput` (logits, hidden_states, attentions), so that is the
# output type advertised in the generated usage docstring.
append_call_sample_docstring(
    FlaxOPTForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxMaskedLMOutput,
    _CONFIG_FOR_DOC,
)
|