| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713 |
- # coding=utf-8
- # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Callable, Optional, Tuple
- import flax
- import flax.linen as nn
- import jax
- import jax.numpy as jnp
- import numpy as np
- from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
- from flax.linen import combine_masks, make_causal_mask
- from flax.linen import partitioning as nn_partitioning
- from flax.linen.attention import dot_product_attention_weights
- from flax.traverse_util import flatten_dict, unflatten_dict
- from jax import lax
- from ...modeling_flax_outputs import (
- FlaxBaseModelOutputWithPastAndCrossAttentions,
- FlaxBaseModelOutputWithPooling,
- FlaxBaseModelOutputWithPoolingAndCrossAttentions,
- FlaxCausalLMOutputWithCrossAttentions,
- FlaxMaskedLMOutput,
- FlaxMultipleChoiceModelOutput,
- FlaxNextSentencePredictorOutput,
- FlaxQuestionAnsweringModelOutput,
- FlaxSequenceClassifierOutput,
- FlaxTokenClassifierOutput,
- )
- from ...modeling_flax_utils import (
- ACT2FN,
- FlaxPreTrainedModel,
- append_call_sample_docstring,
- append_replace_return_docstrings,
- overwrite_call_docstring,
- )
- from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
- from .configuration_bert import BertConfig
# Logger for this module, obtained from the library's own logging helpers.
logger = logging.get_logger(__name__)

# Identifiers used when rendering documentation examples
# (presumably consumed by the docstring helpers imported above — verify against callers).
_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"

# Short alias for Flax's rematerialisation transform; applied to the layer class when
# gradient checkpointing is enabled.
remat = nn_partitioning.remat
@flax.struct.dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):
    """
    Container for the outputs of [`BertForPreTraining`].

    Args:
        prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Scores produced by the language-modeling head for every vocabulary token, before SoftMax.
        seq_relationship_logits (`jnp.ndarray` of shape `(batch_size, 2)`):
            Scores produced by the next-sequence-prediction (classification) head for the True/False continuation
            classes, before SoftMax.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`: one entry for the
            embedding output followed by one entry per layer output.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` of shape `(batch_size, num_heads, sequence_length, sequence_length)`, one per
            layer, holding the post-softmax attention weights used to average the values in each head.
    """

    # Field order is part of the dataclass contract (positional construction / tuple conversion); keep it stable.
    prediction_logits: jnp.ndarray = None
    seq_relationship_logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
# Shared class-level documentation prepended to the public Flax BERT model classes.
# The `dtype` parameter was previously documented twice verbatim; it now appears once.
BERT_START_DOCSTRING = r"""
    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`BertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""
# Shared argument documentation for the model call methods. `{0}` is a `str.format` placeholder for the shape of the
# per-token inputs. `head_mask` is per layer and per head (see the layer-count check in the layer collection and the
# `(num_hidden_layers, num_attention_heads)` mask built in `init_weights`), so it does not use the placeholder.
BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        head_mask (`numpy.ndarray` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class FlaxBertEmbeddings(nn.Module):
    """Sum word, position and token-type embeddings, then apply LayerNorm and dropout."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        # One stateless initializer function is shared by all three tables.
        init_fn = jax.nn.initializers.normal(stddev=cfg.initializer_range)

        self.word_embeddings = nn.Embed(cfg.vocab_size, cfg.hidden_size, embedding_init=init_fn, dtype=self.dtype)
        self.position_embeddings = nn.Embed(
            cfg.max_position_embeddings, cfg.hidden_size, embedding_init=init_fn, dtype=self.dtype
        )
        self.token_type_embeddings = nn.Embed(
            cfg.type_vocab_size, cfg.hidden_size, embedding_init=init_fn, dtype=self.dtype
        )
        self.LayerNorm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=cfg.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        """
        Embed the token, position and segment indices.

        `attention_mask` is part of the signature but is not read by this module. All index inputs are cast to
        int32 before lookup.
        """
        word_vecs = self.word_embeddings(input_ids.astype("i4"))
        pos_vecs = self.position_embeddings(position_ids.astype("i4"))
        segment_vecs = self.token_type_embeddings(token_type_ids.astype("i4"))

        combined = word_vecs + segment_vecs + pos_vecs
        normalized = self.LayerNorm(combined)
        return self.dropout(normalized, deterministic=deterministic)
class FlaxBertSelfAttention(nn.Module):
    """
    Multi-head dot-product attention for BERT.

    Behaves as self-attention by default and as cross-attention when `key_value_states` is passed to `__call__`.
    With `causal=True` a causal mask is applied. When `init_cache=True`, or once a `"cache"` collection exists,
    keys and values are accumulated step by step for incremental decoding.
    """

    config: BertConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            # Both fragments must be f-strings so the offending configuration values are interpolated into the
            # message (without the prefix the braces were emitted literally).
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
                f" : {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        if self.causal:
            # Boolean causal mask sized for the longest supported sequence; `__call__` slices the needed window.
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        # Keep the first two axes (presumably batch and sequence) and split the last into
        # (num_attention_heads, head_dim).
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        # Inverse of `_split_heads`: fold the trailing head axes back into a single hidden_size axis.
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))

    @nn.compact
    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slighly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        """
        Compute attention over `hidden_states` (self-attention) or over `key_value_states` (cross-attention).

        Returns `(attn_output,)`, or `(attn_output, attn_weights)` when `output_attentions` is True.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.query(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.key(key_value_states)
            value_states = self.value(key_value_states)
        else:
            # self_attention
            key_states = self.key(hidden_states)
            value_states = self.value(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # Decoding with a cache: take the rows of the causal mask for the current cache position.
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            # Insert head and query axes so the key mask broadcasts against the attention weights.
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias: 0 where attending is allowed, dtype minimum elsewhere
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        # Only draw a dropout RNG when dropout is actually active.
        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Mask heads if we want to: scale each head's weights by its entry in layer_head_mask.
        if layer_head_mask is not None:
            attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
class FlaxBertSelfOutput(nn.Module):
    """Project attention output, apply dropout, then add the residual and LayerNorm."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        cfg = self.config
        self.dense = nn.Dense(
            cfg.hidden_size,
            kernel_init=jax.nn.initializers.normal(cfg.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=cfg.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=cfg.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        # `input_tensor` is the residual branch added back before normalisation.
        projected = self.dropout(self.dense(hidden_states), deterministic=deterministic)
        return self.LayerNorm(projected + input_tensor)
class FlaxBertAttention(nn.Module):
    """Attention sub-block: multi-head attention followed by the output projection / residual / LayerNorm."""

    config: BertConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.self = FlaxBertSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
        self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states=None,
        init_cache=False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        """
        Return `(hidden_states,)`, or `(hidden_states, attention_weights)` when `output_attentions` is True.

        The raw mask is passed through unchanged. `FlaxBertSelfAttention` expands it from `(*batch_sizes, kv_length)`
        so that it broadcasts against weights of shape `(*batch_sizes, num_heads, q_length, kv_length)`.
        """
        self_attention_result = self.self(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            key_value_states=key_value_states,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        # The original `hidden_states` serves as the residual input to the output block.
        block_output = self.output(self_attention_result[0], hidden_states, deterministic=deterministic)

        if not output_attentions:
            return (block_output,)
        return (block_output, self_attention_result[1])
class FlaxBertIntermediate(nn.Module):
    """Feed-forward expansion: dense projection to `intermediate_size` followed by the configured activation."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        init_fn = jax.nn.initializers.normal(self.config.initializer_range)
        self.dense = nn.Dense(self.config.intermediate_size, kernel_init=init_fn, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        return self.activation(self.dense(hidden_states))
class FlaxBertOutput(nn.Module):
    """Feed-forward contraction back to `hidden_size`, dropout, then residual add and LayerNorm."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        init_fn = jax.nn.initializers.normal(self.config.initializer_range)
        self.dense = nn.Dense(self.config.hidden_size, kernel_init=init_fn, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        # `attention_output` is the residual coming from the attention sub-block.
        contracted = self.dropout(self.dense(hidden_states), deterministic=deterministic)
        return self.LayerNorm(contracted + attention_output)
class FlaxBertLayer(nn.Module):
    """
    One transformer block: self-attention, optional cross-attention over encoder states, then feed-forward.

    The self-attention is causal when `config.is_decoder` is set. The cross-attention sub-module exists only when
    `config.add_cross_attention` is set, and it is used whenever `encoder_hidden_states` is provided.
    """

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxBertAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
        self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxBertOutput(self.config, dtype=self.dtype)
        if self.config.add_cross_attention:
            self.crossattention = FlaxBertAttention(self.config, causal=False, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        """
        Return `(hidden_states,)`.

        When `output_attentions` is True, the self-attention weights are appended, followed by the cross-attention
        weights if cross-attention ran.
        """
        uses_cross_attention = encoder_hidden_states is not None

        self_attn_result = self.attention(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        block_input = self_attn_result[0]

        cross_attn_result = None
        if uses_cross_attention:
            cross_attn_result = self.crossattention(
                block_input,
                attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                key_value_states=encoder_hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            block_input = cross_attn_result[0]

        expanded = self.intermediate(block_input)
        layer_output = self.output(expanded, block_input, deterministic=deterministic)

        result = (layer_output,)
        if output_attentions:
            result += (self_attn_result[1],)
            # The cross-attention weights only exist when that branch ran with output_attentions enabled.
            if uses_cross_attention:
                result += (cross_attn_result[1],)
        return result
class FlaxBertLayerCollection(nn.Module):
    """Stack of `config.num_hidden_layers` BERT layers, optionally wrapped with `remat` for gradient checkpointing."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        layer_cls = FlaxBertLayer
        if self.gradient_checkpointing:
            # NOTE(review): (5, 6, 7) is meant to mark init_cache, deterministic and output_attentions as static
            # (positions 5-7 of the positional layer call in __call__, counting hidden_states as 0). Confirm this
            # matches how the installed Flax version counts `self` in `static_argnums`.
            layer_cls = remat(FlaxBertLayer, static_argnums=(5, 6, 7))
        # Explicit string names keep the parameter tree keyed "0", "1", ... regardless of the wrapper used.
        self.layers = [
            layer_cls(self.config, name=str(idx), dtype=self.dtype) for idx in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """
        Run every layer in sequence.

        Optionally collects the per-layer hidden states and attention weights. Raises `ValueError` when `head_mask`
        does not provide exactly one entry per layer.
        """
        has_encoder = encoder_hidden_states is not None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and has_encoder) else None

        layer_count = len(self.layers)
        if head_mask is not None and head_mask.shape[0] != layer_count:
            raise ValueError(
                f"The head_mask should be specified for {layer_count} layers, but it is for "
                f" {head_mask.shape[0]}."
            )

        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            per_layer_head_mask = None if head_mask is None else head_mask[idx]
            # Arguments are passed positionally so that `remat`'s static_argnums line up.
            layer_result = layer(
                hidden_states,
                attention_mask,
                per_layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                deterministic,
                output_attentions,
            )
            hidden_states = layer_result[0]

            if output_attentions:
                all_attentions += (layer_result[1],)
                if has_encoder:
                    all_cross_attentions += (layer_result[2],)

        # The state after the final layer is recorded as well.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            candidates = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
            return tuple(item for item in candidates if item is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
class FlaxBertEncoder(nn.Module):
    """Thin wrapper that owns the layer stack (under the attribute name `layer`) and forwards every argument to it."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        self.layer = FlaxBertLayerCollection(
            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Pure delegation; the return value is whatever the layer collection produces.
        stack = self.layer
        return stack(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
class FlaxBertPooler(nn.Module):
    """Pool a sequence by projecting the hidden state at index 0 of axis 1 and applying tanh."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        init_fn = jax.nn.initializers.normal(self.config.initializer_range)
        self.dense = nn.Dense(self.config.hidden_size, kernel_init=init_fn, dtype=self.dtype)

    def __call__(self, hidden_states):
        first_position = hidden_states[:, 0]
        return nn.tanh(self.dense(first_position))
class FlaxBertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the vocabulary projection."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # No explicit kernel_init here: the Dense layer uses Flax's default initializer.
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states):
        activated = self.activation(self.dense(hidden_states))
        return self.LayerNorm(activated)
class FlaxBertLMPredictionHead(nn.Module):
    """
    Masked-LM head: transform the hidden states, project them to `config.vocab_size` and add a learned bias.

    When `shared_embedding` is given, its transpose is used as the projection kernel instead of the module's own
    decoder weights.
    """

    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
        self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))

    def __call__(self, hidden_states, shared_embedding=None):
        transformed = self.transform(hidden_states)

        if shared_embedding is None:
            logits = self.decoder(transformed)
        else:
            # Tie output weights to the provided embedding matrix by applying the Dense with an injected kernel.
            logits = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, transformed)

        return logits + jnp.asarray(self.bias, self.dtype)
class FlaxBertOnlyMLMHead(nn.Module):
    """Wrap a single `FlaxBertLMPredictionHead` stored under the attribute name `predictions`."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, shared_embedding=None):
        return self.predictions(hidden_states, shared_embedding=shared_embedding)
class FlaxBertOnlyNSPHead(nn.Module):
    """Two-way next-sentence-prediction classifier applied to the pooled output."""

    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.seq_relationship = nn.Dense(2, dtype=self.dtype)

    def __call__(self, pooled_output):
        logits = self.seq_relationship(pooled_output)
        return logits
class FlaxBertPreTrainingHeads(nn.Module):
    """Both pretraining heads together: vocabulary prediction and 2-way next-sentence classification."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
        self.seq_relationship = nn.Dense(2, dtype=self.dtype)

    def __call__(self, hidden_states, pooled_output, shared_embedding=None):
        # Returns (prediction_scores, seq_relationship_score); MLM scores are computed first.
        return (
            self.predictions(hidden_states, shared_embedding=shared_embedding),
            self.seq_relationship(pooled_output),
        )
class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BertConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        # Any extra keyword arguments (e.g. `add_pooling_layer`) are forwarded to the concrete Flax module.
        module = self.module_class(
            config=config,
            dtype=dtype,
            gradient_checkpointing=gradient_checkpointing,
            **kwargs,
        )
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def enable_gradient_checkpointing(self):
        """
        Replace the wrapped module with an otherwise identical copy that has `gradient_checkpointing=True`.

        `Module.clone` copies every field of the current module and overrides only the one given. Extra
        constructor options passed through `**kwargs` in `__init__` (such as `add_pooling_layer=False`) are
        therefore kept. Rebuilding the module from `config` and `dtype` alone would reset them to their defaults,
        and the architecture would no longer match the loaded parameters.
        """
        self._module = self.module.clone(gradient_checkpointing=True)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        """
        Initialize the module's parameters by running `module.init` on dummy inputs.

        Args:
            rng: PRNG key, split into the "params" and "dropout" streams used by `module.init`.
            input_shape: Shape of the dummy `input_ids` used for the initialization trace.
            params: Optional parameter tree. When given, every key listed in `self._missing_keys` is filled in
                from the freshly initialized parameters, and `self._missing_keys` is then reset.

        Returns:
            The completed and frozen `params` when `params` was given, otherwise the random parameters.
        """
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            # Dummy encoder outputs, so that parameters used by the cross-attention path are created too.
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(
                rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
            )

        random_params = module_init_outputs["params"]

        if params is None:
            return random_params

        # Merge: keep the provided weights and fill only the keys that were missing from them.
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))

    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        # The third positional argument lands on `position_ids` for `FlaxBertForCausalLMModule`, whose signature
        # is (input_ids, attention_mask, position_ids, ...). NOTE(review): for modules ordered like
        # `FlaxBertModule` it would bind to `token_type_ids` instead; confirm this is only used by the causal LM.
        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: dict = None,
    ):
        # Unset output flags fall back to the model configuration.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if head_mask is None:
            head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # NOTE: `head_mask` is cast to int32 below, so fractional mask values would be truncated toward zero.
        if self.config.add_cross_attention:
            # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
            # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
            # changed by FlaxBertAttention module
            # NOTE(review): this check uses truthiness but the unpacking below uses `is not None`, so an empty dict
            # would skip the mutable cache yet still be unpacked; confirm callers never pass `{}`.
            if past_key_values:
                inputs["cache"] = past_key_values
                mutable = ["cache"]
            else:
                mutable = False

            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                token_type_ids=jnp.array(token_type_ids, dtype="i4"),
                position_ids=jnp.array(position_ids, dtype="i4"),
                head_mask=jnp.array(head_mask, dtype="i4"),
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                deterministic=not train,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                rngs=rngs,
                mutable=mutable,
            )

            # add updated cache to model output
            if past_key_values is not None and return_dict:
                outputs, past_key_values = outputs
                outputs["past_key_values"] = unfreeze(past_key_values["cache"])
                return outputs
            elif past_key_values is not None and not return_dict:
                outputs, past_key_values = outputs
                # Insert the updated cache right after the first tuple element.
                outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        else:
            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                token_type_ids=jnp.array(token_type_ids, dtype="i4"),
                position_ids=jnp.array(position_ids, dtype="i4"),
                head_mask=jnp.array(head_mask, dtype="i4"),
                deterministic=not train,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                rngs=rngs,
            )

        return outputs
class FlaxBertModule(nn.Module):
    """Core BERT network: embeddings, transformer encoder and an optional tanh pooler over the first token."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True
    gradient_checkpointing: bool = False

    def setup(self):
        self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBertEncoder(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Missing segment ids default to all zeros, shaped like `input_ids`.
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # Missing position ids default to 0..seq_len-1, broadcast to `input_ids.shape`.
        if position_ids is None:
            seq_len = jnp.atleast_2d(input_ids).shape[-1]
            position_ids = jnp.broadcast_to(jnp.arange(seq_len), input_ids.shape)

        embedded = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        encoder_outputs = self.encoder(
            embedded,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = encoder_outputs[0]
        pooled = self.pooler(sequence_output) if self.add_pooling_layer else None

        if return_dict:
            return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                cross_attentions=encoder_outputs.cross_attentions,
            )

        # Tuple form: the pooled vector is included only when the pooling layer is enabled.
        leading = (sequence_output,) if pooled is None else (sequence_output, pooled)
        return leading + encoder_outputs[1:]
@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class FlaxBertModel(FlaxBertPreTrainedModel):
    module_class = FlaxBertModule


# With `return_dict=True`, `FlaxBertModule` returns `FlaxBaseModelOutputWithPoolingAndCrossAttentions`, which also
# carries `cross_attentions`. Document that type, not the narrower `FlaxBaseModelOutputWithPooling`.
append_call_sample_docstring(
    FlaxBertModel,
    _CHECKPOINT_FOR_DOC,
    FlaxBaseModelOutputWithPoolingAndCrossAttentions,
    _CONFIG_FOR_DOC,
)
class FlaxBertForPreTrainingModule(nn.Module):
    """BERT encoder plus both pretraining heads (vocabulary prediction and next-sentence classification)."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        bert_outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # With tied weights, the prediction head reuses the word-embedding matrix in place of its own decoder kernel.
        shared_embedding = (
            self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            if self.config.tie_word_embeddings
            else None
        )

        sequence_output, pooled_output = bert_outputs[0], bert_outputs[1]
        prediction_scores, seq_relationship_score = self.cls(
            sequence_output, pooled_output, shared_embedding=shared_embedding
        )

        if return_dict:
            return FlaxBertForPreTrainingOutput(
                prediction_logits=prediction_scores,
                seq_relationship_logits=seq_relationship_score,
                hidden_states=bert_outputs.hidden_states,
                attentions=bert_outputs.attentions,
            )
        # Skip the encoder's (sequence_output, pooled_output) pair; the two score tensors take their place.
        return (prediction_scores, seq_relationship_score) + bert_outputs[2:]
# Public pretraining model; the Flax computation itself lives in `FlaxBertForPreTrainingModule`.
@add_start_docstrings(
"""
Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
sentence prediction (classification)` head.
""",
    BERT_START_DOCSTRING,
)
class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
    module_class = FlaxBertForPreTrainingModule


# Usage example appended to the shared input documentation of `__call__`.
FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, FlaxBertForPreTraining
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = FlaxBertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.prediction_logits
>>> seq_relationship_logits = outputs.seq_relationship_logits
```
"""

overwrite_call_docstring(
    FlaxBertForPreTraining,
    BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxBertForPreTraining,
    output_type=FlaxBertForPreTrainingOutput,
    config_class=_CONFIG_FOR_DOC,
)
class FlaxBertForMaskedLMModule(nn.Module):
    """BERT encoder (without pooler) followed by the masked-language-modeling head."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        bert_outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = bert_outputs[0]

        # Weight tying: hand the word-embedding matrix to the head when the config asks for it.
        shared_embedding = (
            self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            if self.config.tie_word_embeddings
            else None
        )
        logits = self.cls(sequence_output, shared_embedding=shared_embedding)

        if return_dict:
            return FlaxMaskedLMOutput(
                logits=logits,
                hidden_states=bert_outputs.hidden_states,
                attentions=bert_outputs.attentions,
            )
        return (logits,) + bert_outputs[1:]
# Public masked-LM model; the computation is delegated to `FlaxBertForMaskedLMModule`.
@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMaskedLMModule


append_call_sample_docstring(
    FlaxBertForMaskedLM,
    _CHECKPOINT_FOR_DOC,
    FlaxMaskedLMOutput,
    _CONFIG_FOR_DOC,
)
class FlaxBertForNextSentencePredictionModule(nn.Module):
    """BERT encoder with pooler, followed by the 2-way next-sentence-prediction classifier."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # An explicit `None` falls back to the config setting.
        if return_dict is None:
            return_dict = self.config.return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 1 is the pooled first-token representation.
        nsp_logits = self.cls(bert_outputs[1])

        if return_dict:
            return FlaxNextSentencePredictorOutput(
                logits=nsp_logits,
                hidden_states=bert_outputs.hidden_states,
                attentions=bert_outputs.attentions,
            )
        return (nsp_logits,) + bert_outputs[2:]
# Public next-sentence-prediction model backed by `FlaxBertForNextSentencePredictionModule`.
@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top.""",
    BERT_START_DOCSTRING,
)
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
    module_class = FlaxBertForNextSentencePredictionModule


# Usage example appended to the shared input documentation of `__call__`.
FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING = """
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, FlaxBertForNextSentencePrediction
>>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = FlaxBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors="jax")
>>> outputs = model(**encoding)
>>> logits = outputs.logits
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
```
"""

overwrite_call_docstring(
    FlaxBertForNextSentencePrediction,
    BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxBertForNextSentencePrediction,
    output_type=FlaxNextSentencePredictorOutput,
    config_class=_CONFIG_FOR_DOC,
)
class FlaxBertForSequenceClassificationModule(nn.Module):
    """BERT encoder with pooler, then dropout and a dense layer producing `num_labels` logits."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # Prefer the dedicated classifier dropout rate; fall back to the general hidden dropout rate.
        dropout_rate = self.config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = self.config.hidden_dropout_prob
        self.dropout = nn.Dropout(rate=dropout_rate)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        bert_outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        features = self.dropout(bert_outputs[1], deterministic=deterministic)
        logits = self.classifier(features)

        if return_dict:
            return FlaxSequenceClassifierOutput(
                logits=logits,
                hidden_states=bert_outputs.hidden_states,
                attentions=bert_outputs.attentions,
            )
        return (logits,) + bert_outputs[2:]
# Public sequence-classification model backed by `FlaxBertForSequenceClassificationModule`.
@add_start_docstrings(
"""
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
output) e.g. for GLUE tasks.
""",
    BERT_START_DOCSTRING,
)
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
    module_class = FlaxBertForSequenceClassificationModule


append_call_sample_docstring(
    FlaxBertForSequenceClassification, _CHECKPOINT_FOR_DOC, FlaxSequenceClassifierOutput, _CONFIG_FOR_DOC
)
class FlaxBertForMultipleChoiceModule(nn.Module):
    """Scores each choice with a single logit computed from BERT's pooled output."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Axis 1 of `input_ids` holds the choices.
        num_choices = input_ids.shape[1]

        def collapse_choices(tensor):
            # Merge every leading axis into one, keeping the last axis; `None` passes through unchanged.
            return None if tensor is None else tensor.reshape(-1, tensor.shape[-1])

        input_ids = collapse_choices(input_ids)
        attention_mask = collapse_choices(attention_mask)
        token_type_ids = collapse_choices(token_type_ids)
        position_ids = collapse_choices(position_ids)

        bert_outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        features = self.dropout(bert_outputs[1], deterministic=deterministic)
        # One logit per flattened row, regrouped so that each row holds `num_choices` scores.
        reshaped_logits = self.classifier(features).reshape(-1, num_choices)

        if return_dict:
            return FlaxMultipleChoiceModelOutput(
                logits=reshaped_logits,
                hidden_states=bert_outputs.hidden_states,
                attentions=bert_outputs.attentions,
            )
        return (reshaped_logits,) + bert_outputs[2:]
# Public multiple-choice model backed by `FlaxBertForMultipleChoiceModule`.
@add_start_docstrings(
"""
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
softmax) e.g. for RocStories/SWAG tasks.
""",
    BERT_START_DOCSTRING,
)
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMultipleChoiceModule


# Inputs carry an extra `num_choices` axis, so the shape placeholder in the input docs differs here.
overwrite_call_docstring(
    FlaxBertForMultipleChoice,
    BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"),
)
append_call_sample_docstring(
    FlaxBertForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)
class FlaxBertForTokenClassificationModule(nn.Module):
    """BERT encoder (no pooler) with dropout and a per-token dense classifier producing `num_labels` logits."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        # Prefer the dedicated classifier dropout rate; fall back to the general hidden dropout rate.
        dropout_rate = self.config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = self.config.hidden_dropout_prob
        self.dropout = nn.Dropout(rate=dropout_rate)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        bert_outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        token_features = self.dropout(bert_outputs[0], deterministic=deterministic)
        logits = self.classifier(token_features)

        if return_dict:
            return FlaxTokenClassifierOutput(
                logits=logits,
                hidden_states=bert_outputs.hidden_states,
                attentions=bert_outputs.attentions,
            )
        return (logits,) + bert_outputs[1:]
# Public token-classification model backed by `FlaxBertForTokenClassificationModule`.
@add_start_docstrings(
"""
Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
    BERT_START_DOCSTRING,
)
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
    module_class = FlaxBertForTokenClassificationModule


append_call_sample_docstring(
    FlaxBertForTokenClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)
class FlaxBertForQuestionAnsweringModule(nn.Module):
    """BERT encoder (no pooler) with a dense layer whose outputs are split into span start and end logits."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        bert_outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        span_logits = self.qa_outputs(bert_outputs[0])
        # The last axis is split into `num_labels` pieces and unpacked into exactly two names,
        # so this only works when `config.num_labels == 2`.
        start_logits, end_logits = (
            piece.squeeze(-1) for piece in jnp.split(span_logits, self.config.num_labels, axis=-1)
        )

        if return_dict:
            return FlaxQuestionAnsweringModelOutput(
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=bert_outputs.hidden_states,
                attentions=bert_outputs.attentions,
            )
        return (start_logits, end_logits) + bert_outputs[1:]
# Public extractive question-answering model backed by `FlaxBertForQuestionAnsweringModule`.
@add_start_docstrings(
"""
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
    BERT_START_DOCSTRING,
)
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
    module_class = FlaxBertForQuestionAnsweringModule


append_call_sample_docstring(
    FlaxBertForQuestionAnswering, _CHECKPOINT_FOR_DOC, FlaxQuestionAnsweringModelOutput, _CONFIG_FOR_DOC
)
class FlaxBertForCausalLMModule(nn.Module):
    """BERT encoder (no pooler) with the language-modeling head, supporting cross-attention inputs and a cache."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    # `position_ids` is the third positional parameter here (not `token_type_ids` as in `FlaxBertModule`).
    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        token_type_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        bert_outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = bert_outputs[0]

        # Weight tying: hand the word-embedding matrix to the head when the config asks for it.
        shared_embedding = (
            self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            if self.config.tie_word_embeddings
            else None
        )
        logits = self.cls(sequence_output, shared_embedding=shared_embedding)

        if return_dict:
            return FlaxCausalLMOutputWithCrossAttentions(
                logits=logits,
                hidden_states=bert_outputs.hidden_states,
                attentions=bert_outputs.attentions,
                cross_attentions=bert_outputs.cross_attentions,
            )
        return (logits,) + bert_outputs[1:]
# Public causal-LM model with helpers used by the generation loop.
@add_start_docstrings(
"""
Bert Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
autoregressive tasks.
""",
    BERT_START_DOCSTRING,
)
class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
    module_class = FlaxBertForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        """
        Build the initial generation state: a fresh cache, a `(batch_size, max_length)` attention mask and the
        position ids of the prompt tokens.
        """
        batch_size, seq_length = input_ids.shape
        past_key_values = self.init_cache(batch_size, max_length)

        # Usually the positions between the prompt length and max_length would be zeroed in the mask. The
        # decoder's causal mask already hides them, so one static all-ones mask of length max_length is used.
        # A fixed shape is cheaper to compile.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")

        if attention_mask is None:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
        else:
            # Positions count only unmasked tokens; the caller's mask overwrites the prefix of the static one.
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        """Carry the updated cache forward and advance to the single next position."""
        next_position = model_kwargs["position_ids"][:, -1:] + 1
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = next_position
        return model_kwargs


append_call_sample_docstring(
    FlaxBertForCausalLM, _CHECKPOINT_FOR_DOC, FlaxCausalLMOutputWithCrossAttentions, _CONFIG_FOR_DOC
)
|