
# coding=utf-8
# Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Optional, Tuple

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen import partitioning as nn_partitioning
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax

from ...modeling_flax_outputs import (
    FlaxBaseModelOutputWithPastAndCrossAttentions,
    FlaxBaseModelOutputWithPooling,
    FlaxBaseModelOutputWithPoolingAndCrossAttentions,
    FlaxCausalLMOutputWithCrossAttentions,
    FlaxMaskedLMOutput,
    FlaxMultipleChoiceModelOutput,
    FlaxNextSentencePredictorOutput,
    FlaxQuestionAnsweringModelOutput,
    FlaxSequenceClassifierOutput,
    FlaxTokenClassifierOutput,
)
from ...modeling_flax_utils import (
    ACT2FN,
    FlaxPreTrainedModel,
    append_call_sample_docstring,
    append_replace_return_docstrings,
    overwrite_call_docstring,
)
from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
from .configuration_bert import BertConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "google-bert/bert-base-uncased"
_CONFIG_FOR_DOC = "BertConfig"

remat = nn_partitioning.remat


@flax.struct.dataclass
class FlaxBertForPreTrainingOutput(ModelOutput):
    """
    Output type of [`BertForPreTraining`].

    Args:
        prediction_logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        seq_relationship_logits (`jnp.ndarray` of shape `(batch_size, 2)`):
            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
            before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    prediction_logits: jnp.ndarray = None
    seq_relationship_logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None


BERT_START_DOCSTRING = r"""

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

    This model is also a
    [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
    a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
    behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`BertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""

BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`numpy.ndarray` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.
        head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
            Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

class FlaxBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.word_embeddings = nn.Embed(
            self.config.vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.token_type_embeddings = nn.Embed(
            self.config.type_vocab_size,
            self.config.hidden_size,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
        # Embed
        inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
        position_embeds = self.position_embeddings(position_ids.astype("i4"))
        token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))

        # Sum all embeddings
        hidden_states = inputs_embeds + token_type_embeddings + position_embeds

        # Layer Norm
        hidden_states = self.LayerNorm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states


class FlaxBertSelfAttention(nn.Module):
    config: BertConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
                f"`config.num_attention_heads`: {self.config.num_attention_heads}"
            )

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
            )

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))

    @nn.compact
    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """
        This function takes projected key, value states from a single input token and concatenates the states to cached
        states from previous steps. This function is slightly adapted from the official Flax repository:
        https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
        """
        # detect if we're initializing by absence of existing cache data.
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))

        if is_initialized:
            *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
            # update key, value caches with our new 1d spatial slices
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated_cache_vectors = query.shape[1]
            cache_index.value = cache_index.value + num_updated_cache_vectors
            # causal mask for cached decoder self-attention: our single query position should only attend to those key
            # positions that have already been generated and cached, not the remaining zero elements.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
            )
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.query(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.key(key_value_states)
            value_states = self.value(key_value_states)
        else:
            # self_attention
            key_states = self.key(hidden_states)
            value_states = self.value(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
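
# Shape bookkeeping for the attention module above, as a comment-only sketch (the concrete sizes are
# illustrative, not taken from any particular config): with hidden_size=768 and num_attention_heads=12,
# `_split_heads` reshapes the projected states from (batch, seq_len, 768) to (batch, seq_len, 12, 64),
# the einsums contract over the 64-dim head axis, and the final reshape merges the heads back to
# (batch, seq_len, 768) before `FlaxBertSelfOutput` below applies the output projection and residual.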

class FlaxBertSelfOutput(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class FlaxBertAttention(nn.Module):
    config: BertConfig
    causal: bool = False
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.self = FlaxBertSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
        self.output = FlaxBertSelfOutput(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        key_value_states=None,
        init_cache=False,
        deterministic=True,
        output_attentions: bool = False,
    ):
        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
        attn_outputs = self.self(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            key_value_states=key_value_states,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]
        hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_outputs[1],)

        return outputs


class FlaxBertIntermediate(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.hidden_act]

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states


class FlaxBertOutput(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_output, deterministic: bool = True):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.LayerNorm(hidden_states + attention_output)
        return hidden_states


class FlaxBertLayer(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.attention = FlaxBertAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
        self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxBertOutput(self.config, dtype=self.dtype)
        if self.config.add_cross_attention:
            self.crossattention = FlaxBertAttention(self.config, causal=False, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        # Self Attention
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attention_output = attention_outputs[0]

        # Cross-Attention Block
        if encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                key_value_states=encoder_hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]

        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attention_outputs[1],)
            if encoder_hidden_states is not None:
                outputs += (cross_attention_outputs[1],)
        return outputs


class FlaxBertLayerCollection(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        if self.gradient_checkpointing:
            FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7))
            self.layers = [
                FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]
        else:
            self.layers = [
                FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
            ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # Check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.shape[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
                    f"{head_mask.shape[0]}."
                )

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states,
                attention_mask,
                head_mask[i] if head_mask is not None else None,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                deterministic,
                output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )


class FlaxBertEncoder(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        self.layer = FlaxBertLayerCollection(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return self.layer(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


class FlaxBertPooler(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        cls_hidden_state = hidden_states[:, 0]
        cls_hidden_state = self.dense(cls_hidden_state)
        return nn.tanh(cls_hidden_state)


class FlaxBertPredictionHeadTransform(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.activation = ACT2FN[self.config.hidden_act]
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.activation(hidden_states)
        return self.LayerNorm(hidden_states)


class FlaxBertLMPredictionHead(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.transform = FlaxBertPredictionHeadTransform(self.config, dtype=self.dtype)
        self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype, use_bias=False)
        self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))

    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.transform(hidden_states)

        if shared_embedding is not None:
            hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            hidden_states = self.decoder(hidden_states)

        bias = jnp.asarray(self.bias, self.dtype)
        hidden_states += bias
        return hidden_states
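
# Note on weight tying in the head above: when `config.tie_word_embeddings` is set, the modules further
# down pass the word embedding matrix in as `shared_embedding`, and the decoder projection reuses it
# (transposed) as its kernel instead of learning a separate output matrix. A comment-only shape sketch
# (the concrete sizes are illustrative): `shared_embedding` has shape (vocab_size, hidden_size), e.g.
# (30522, 768), so `shared_embedding.T` is (768, 30522), which is exactly the kernel shape that
# `nn.Dense(vocab_size)` expects for inputs of width hidden_size.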

class FlaxBertOnlyMLMHead(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, shared_embedding=None):
        hidden_states = self.predictions(hidden_states, shared_embedding=shared_embedding)
        return hidden_states


class FlaxBertOnlyNSPHead(nn.Module):
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.seq_relationship = nn.Dense(2, dtype=self.dtype)

    def __call__(self, pooled_output):
        return self.seq_relationship(pooled_output)


class FlaxBertPreTrainingHeads(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.predictions = FlaxBertLMPredictionHead(self.config, dtype=self.dtype)
        self.seq_relationship = nn.Dense(2, dtype=self.dtype)

    def __call__(self, hidden_states, pooled_output, shared_embedding=None):
        prediction_scores = self.predictions(hidden_states, shared_embedding=shared_embedding)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class FlaxBertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = BertConfig
    base_model_prefix = "bert"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BertConfig,
        input_shape: Tuple = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        module = self.module_class(
            config=config,
            dtype=dtype,
            gradient_checkpointing=gradient_checkpointing,
            **kwargs,
        )
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def enable_gradient_checkpointing(self):
        self._module = self.module_class(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=True,
        )
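
    # A minimal usage sketch for the method above (assuming the "google-bert/bert-base-uncased"
    # checkpoint is available): gradient checkpointing can be requested via the `gradient_checkpointing`
    # constructor argument or toggled on after loading.
    #
    #     >>> from transformers import FlaxBertModel
    #
    #     >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-uncased")
    #     >>> model.enable_gradient_checkpointing()  # re-instantiates the module with remat-wrapped layers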

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        token_type_ids = jnp.zeros_like(input_ids)
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
        attention_mask = jnp.ones_like(input_ids)
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        if self.config.add_cross_attention:
            encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
            encoder_attention_mask = attention_mask
            module_init_outputs = self.module.init(
                rngs,
                input_ids,
                attention_mask,
                token_type_ids,
                position_ids,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                return_dict=False,
            )
        else:
            module_init_outputs = self.module.init(
                rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
            )

        random_params = module_init_outputs["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])
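
    # A minimal sketch of how the cache above is typically consumed (hedged: `__call__` below only uses
    # `past_key_values` for configurations with `config.add_cross_attention=True`, i.e. decoder-style
    # variants of this model; the shapes are illustrative):
    #
    #     >>> past_key_values = model.init_cache(batch_size=1, max_length=32)
    #     >>> # feed one token at a time and pass the returned cache back in, so cached keys/values
    #     >>> # from previous steps are reused instead of recomputed
    #     >>> outputs = model(input_ids[:, :1], past_key_values=past_key_values)
    #     >>> past_key_values = outputs.past_key_values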

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        past_key_values: dict = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # init input tensors if not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        if head_mask is None:
            head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        if self.config.add_cross_attention:
            # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
            # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
            # changed by FlaxBertAttention module
            if past_key_values:
                inputs["cache"] = past_key_values
                mutable = ["cache"]
            else:
                mutable = False

            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                token_type_ids=jnp.array(token_type_ids, dtype="i4"),
                position_ids=jnp.array(position_ids, dtype="i4"),
                head_mask=jnp.array(head_mask, dtype="i4"),
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                deterministic=not train,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                rngs=rngs,
                mutable=mutable,
            )

            # add updated cache to model output
            if past_key_values is not None and return_dict:
                outputs, past_key_values = outputs
                outputs["past_key_values"] = unfreeze(past_key_values["cache"])
                return outputs
            elif past_key_values is not None and not return_dict:
                outputs, past_key_values = outputs
                outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        else:
            outputs = self.module.apply(
                inputs,
                jnp.array(input_ids, dtype="i4"),
                jnp.array(attention_mask, dtype="i4"),
                token_type_ids=jnp.array(token_type_ids, dtype="i4"),
                position_ids=jnp.array(position_ids, dtype="i4"),
                head_mask=jnp.array(head_mask, dtype="i4"),
                deterministic=not train,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                rngs=rngs,
            )

        return outputs


class FlaxBertModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    add_pooling_layer: bool = True
    gradient_checkpointing: bool = False

    def setup(self):
        self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)
        self.encoder = FlaxBertEncoder(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # make sure `token_type_ids` is correctly initialized when not passed
        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        # make sure `position_ids` is correctly initialized when not passed
        if position_ids is None:
            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        hidden_states = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        outputs = self.encoder(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None

        if not return_dict:
            # if pooled is None, don't return it
            if pooled is None:
                return (hidden_states,) + outputs[1:]
            return (hidden_states, pooled) + outputs[1:]

        return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=hidden_states,
            pooler_output=pooled,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


@add_start_docstrings(
    "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
    BERT_START_DOCSTRING,
)
class FlaxBertModel(FlaxBertPreTrainedModel):
    module_class = FlaxBertModule


append_call_sample_docstring(FlaxBertModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
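
# A minimal usage sketch for the bare model (mirroring the sample docstring attached above; assumes the
# "google-bert/bert-base-uncased" checkpoint can be downloaded):
#
#     >>> from transformers import AutoTokenizer, FlaxBertModel
#
#     >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
#     >>> model = FlaxBertModel.from_pretrained("google-bert/bert-base-uncased")
#
#     >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
#     >>> outputs = model(**inputs)
#     >>> last_hidden_states = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)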

class FlaxBertForPreTrainingModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.tie_word_embeddings:
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        hidden_states = outputs[0]
        pooled_output = outputs[1]

        prediction_scores, seq_relationship_score = self.cls(
            hidden_states, pooled_output, shared_embedding=shared_embedding
        )

        if not return_dict:
            return (prediction_scores, seq_relationship_score) + outputs[2:]

        return FlaxBertForPreTrainingOutput(
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForPreTraining(FlaxBertPreTrainedModel):
    module_class = FlaxBertForPreTrainingModule


FLAX_BERT_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxBertForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> model = FlaxBertForPreTraining.from_pretrained("google-bert/bert-base-uncased")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.prediction_logits
    >>> seq_relationship_logits = outputs.seq_relationship_logits
    ```
"""

overwrite_call_docstring(
    FlaxBertForPreTraining,
    BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxBertForPreTraining, output_type=FlaxBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxBertForMaskedLMModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings("""Bert Model with a `language modeling` head on top.""", BERT_START_DOCSTRING)
class FlaxBertForMaskedLM(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMaskedLMModule


append_call_sample_docstring(FlaxBertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)
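
# A minimal usage sketch for the masked-LM head (mirroring the sample docstring attached above; the input
# sentence and the recovered token are illustrative assumptions, not guaranteed outputs):
#
#     >>> from transformers import AutoTokenizer, FlaxBertForMaskedLM
#
#     >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
#     >>> model = FlaxBertForMaskedLM.from_pretrained("google-bert/bert-base-uncased")
#
#     >>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="np")
#     >>> logits = model(**inputs).logits  # (batch_size, sequence_length, vocab_size)
#     >>> mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).argmax(axis=-1)
#     >>> predicted_id = logits[0, mask_index[0]].argmax(axis=-1)
#     >>> tokenizer.decode([predicted_id])  # expected to be something like "paris"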

class FlaxBertForNextSentencePredictionModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        seq_relationship_scores = self.cls(pooled_output)

        if not return_dict:
            return (seq_relationship_scores,) + outputs[2:]

        return FlaxNextSentencePredictorOutput(
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """Bert Model with a `next sentence prediction (classification)` head on top.""",
    BERT_START_DOCSTRING,
)
class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel):
    module_class = FlaxBertForNextSentencePredictionModule


FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxBertForNextSentencePrediction

    >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
    >>> model = FlaxBertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

    >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
    >>> encoding = tokenizer(prompt, next_sentence, return_tensors="jax")

    >>> outputs = model(**encoding)
    >>> logits = outputs.logits
    >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
    ```
"""

overwrite_call_docstring(
    FlaxBertForNextSentencePrediction,
    BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_BERT_FOR_NEXT_SENT_PRED_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxBertForNextSentencePrediction, output_type=FlaxNextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
)


class FlaxBertForSequenceClassificationModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(
            self.config.num_labels,
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        if not return_dict:
            return (logits,) + outputs[2:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel):
    module_class = FlaxBertForSequenceClassificationModule


append_call_sample_docstring(
    FlaxBertForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)
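
# A minimal usage sketch for the sequence-classification head (mirroring the sample docstring attached
# above; with the base checkpoint the classifier weights are randomly initialized, so the logits are only
# meaningful after fine-tuning):
#
#     >>> from transformers import AutoTokenizer, FlaxBertForSequenceClassification
#
#     >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
#     >>> model = FlaxBertForSequenceClassification.from_pretrained("google-bert/bert-base-uncased", num_labels=2)
#
#     >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
#     >>> logits = model(**inputs).logits  # (batch_size, num_labels)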


class FlaxBertForMultipleChoiceModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        num_choices = input_ids.shape[1]
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            return (reshaped_logits,) + outputs[2:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel):
    module_class = FlaxBertForMultipleChoiceModule


overwrite_call_docstring(
    FlaxBertForMultipleChoice, BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxBertForMultipleChoice, _CHECKPOINT_FOR_DOC, FlaxMultipleChoiceModelOutput, _CONFIG_FOR_DOC
)
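

def _example_multiple_choice_usage():
    # Illustrative usage sketch for FlaxBertForMultipleChoice; not exercised in this module.
    # The public "bert-base-uncased" checkpoint and the example sentences are assumptions made
    # purely for illustration. Inputs are shaped (batch_size, num_choices, sequence_length); the
    # module above flattens them to (batch_size * num_choices, sequence_length) for the BERT
    # forward pass and reshapes the per-choice scores back to (batch_size, num_choices).
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = FlaxBertForMultipleChoice.from_pretrained("bert-base-uncased")

    prompt = "France has a bread law with strict rules on what is allowed in a traditional baguette."
    choice0 = "It is eaten with a fork and a knife."
    choice1 = "It is eaten while held in the hand."

    encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="jax", padding=True)
    # Add the leading batch dimension: every array becomes (1, num_choices, sequence_length).
    inputs = {key: array[None, :] for key, array in encoding.items()}
    logits = model(**inputs).logits  # (1, num_choices)
    return logits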


class FlaxBertForTokenClassificationModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(rate=classifier_dropout)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForTokenClassification(FlaxBertPreTrainedModel):
    module_class = FlaxBertForTokenClassificationModule


append_call_sample_docstring(
    FlaxBertForTokenClassification, _CHECKPOINT_FOR_DOC, FlaxTokenClassifierOutput, _CONFIG_FOR_DOC
)
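

def _example_token_classification_usage():
    # Illustrative usage sketch for FlaxBertForTokenClassification; not exercised in this module.
    # "bert-base-uncased" is assumed purely for illustration; an NER fine-tuned checkpoint would
    # normally be used. Each token receives one logit per label, so the argmax over the last axis
    # yields a predicted label id per token.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = FlaxBertForTokenClassification.from_pretrained("bert-base-uncased")

    inputs = tokenizer("HuggingFace is based in New York City", return_tensors="jax")
    logits = model(**inputs).logits  # (batch_size, sequence_length, config.num_labels)
    predicted_label_ids = logits.argmax(axis=-1)
    return predicted_label_ids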


class FlaxBertForQuestionAnsweringModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            dtype=self.dtype,
            add_pooling_layer=False,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


@add_start_docstrings(
    """
    Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel):
    module_class = FlaxBertForQuestionAnsweringModule


append_call_sample_docstring(
    FlaxBertForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)
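

def _example_question_answering_usage():
    # Illustrative usage sketch for FlaxBertForQuestionAnswering; not exercised in this module.
    # "bert-base-uncased" is assumed purely for illustration; a SQuAD fine-tuned checkpoint would
    # normally be used. The answer span is recovered by taking the argmax of the start and end
    # logits and decoding the tokens between those positions.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = FlaxBertForQuestionAnswering.from_pretrained("bert-base-uncased")

    question = "Where is HuggingFace based?"
    context = "HuggingFace Inc. is a company based in New York City."
    inputs = tokenizer(question, context, return_tensors="jax")

    outputs = model(**inputs)
    start_index = int(outputs.start_logits.argmax(axis=-1)[0])  # most likely span start token
    end_index = int(outputs.end_logits.argmax(axis=-1)[0])  # most likely span end token
    answer_token_ids = inputs["input_ids"][0, start_index : end_index + 1]
    return tokenizer.decode(answer_token_ids.tolist())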


class FlaxBertForCausalLMModule(nn.Module):
    config: BertConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.bert = FlaxBertModule(
            config=self.config,
            add_pooling_layer=False,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )
        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        token_type_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.bert(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        if self.config.tie_word_embeddings:
            shared_embedding = self.bert.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
        else:
            shared_embedding = None

        # Compute the prediction scores
        logits = self.cls(hidden_states, shared_embedding=shared_embedding)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxCausalLMOutputWithCrossAttentions(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


@add_start_docstrings(
    """
    Bert Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g. for
    autoregressive tasks.
    """,
    BERT_START_DOCSTRING,
)
class FlaxBertForCausalLM(FlaxBertPreTrainedModel):
    module_class = FlaxBertForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation.
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs


append_call_sample_docstring(
    FlaxBertForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)
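

def _example_causal_lm_usage():
    # Illustrative usage sketch for FlaxBertForCausalLM; not exercised in this module. BERT is an
    # encoder by default, so `is_decoder=True` is assumed to be set on the config to enable the
    # causal attention mask; "bert-base-uncased" is assumed purely for illustration (a checkpoint
    # trained for generation would normally be used). The `prepare_inputs_for_generation` /
    # `update_inputs_for_generation` hooks above are what the generation utilities call when
    # decoding autoregressively with a key/value cache.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    config = BertConfig.from_pretrained("bert-base-uncased", is_decoder=True)
    model = FlaxBertForCausalLM.from_pretrained("bert-base-uncased", config=config)

    inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
    logits = model(**inputs).logits  # (batch_size, sequence_length, vocab_size)
    next_token_id = int(logits[0, -1].argmax(axis=-1))  # greedy pick for the next token
    return next_token_id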