modeling_flax_bart.py
  1. # coding=utf-8
  2. # Copyright 2021 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Flax Bart model."""
  16. import math
  17. import random
  18. from functools import partial
  19. from typing import Callable, Optional, Tuple
  20. import flax.linen as nn
  21. import jax
  22. import jax.numpy as jnp
  23. from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
  24. from flax.linen import combine_masks, make_causal_mask
  25. from flax.linen.attention import dot_product_attention_weights
  26. from flax.traverse_util import flatten_dict, unflatten_dict
  27. from jax import lax
  28. from jax.random import PRNGKey
  29. from ...modeling_flax_outputs import (
  30. FlaxBaseModelOutput,
  31. FlaxBaseModelOutputWithPastAndCrossAttentions,
  32. FlaxCausalLMOutputWithCrossAttentions,
  33. FlaxSeq2SeqLMOutput,
  34. FlaxSeq2SeqModelOutput,
  35. FlaxSeq2SeqQuestionAnsweringModelOutput,
  36. FlaxSeq2SeqSequenceClassifierOutput,
  37. )
  38. from ...modeling_flax_utils import (
  39. ACT2FN,
  40. FlaxPreTrainedModel,
  41. append_call_sample_docstring,
  42. append_replace_return_docstrings,
  43. overwrite_call_docstring,
  44. )
  45. from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
  46. from .configuration_bart import BartConfig
  47. logger = logging.get_logger(__name__)
  48. _CHECKPOINT_FOR_DOC = "facebook/bart-base"
  49. _CONFIG_FOR_DOC = "BartConfig"
  50. BART_START_DOCSTRING = r"""
  51. This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
  52. library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
  53. etc.)
  54. This model is also a Flax Linen
  55. [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
  56. regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
  57. Finally, this model supports inherent JAX features such as:
  58. - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
  59. - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
  60. - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
  61. - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
  62. Parameters:
  63. config ([`BartConfig`]): Model configuration class with all the parameters of the model.
  64. Initializing with a config file does not load the weights associated with the model, only the
  65. configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
  66. dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
  67. The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
  68. `jax.numpy.bfloat16` (on TPUs).
  69. This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
  70. specified, all the computation will be performed with the given `dtype`.
  71. **Note that this only specifies the dtype of the computation and does not influence the dtype of model
  72. parameters.**
  73. If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
  74. [`~FlaxPreTrainedModel.to_bf16`].
  75. """
  76. BART_INPUTS_DOCSTRING = r"""
  77. Args:
  78. input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
  79. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  80. it.
  81. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  82. [`PreTrainedTokenizer.__call__`] for details.
  83. [What are input IDs?](../glossary#input-ids)
  84. attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
  85. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  86. - 1 for tokens that are **not masked**,
  87. - 0 for tokens that are **masked**.
  88. [What are attention masks?](../glossary#attention-mask)
  89. decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
  90. Indices of decoder input sequence tokens in the vocabulary.
  91. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  92. [`PreTrainedTokenizer.__call__`] for details.
  93. [What are decoder input IDs?](../glossary#decoder-input-ids)
  94. For translation and summarization training, `decoder_input_ids` should be provided. If no
  95. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  96. for denoising pre-training following the paper.
  97. decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
  98. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  99. be used by default.
  100. If you want to change padding behavior, you should modify it to your needs. See diagram 1 in [the
  101. paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
  102. position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
  103. Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
  104. config.max_position_embeddings - 1]`.
  105. decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
  106. Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
  107. range `[0, config.max_position_embeddings - 1]`.
  108. output_attentions (`bool`, *optional*):
  109. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  110. tensors for more detail.
  111. output_hidden_states (`bool`, *optional*):
  112. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  113. more detail.
  114. return_dict (`bool`, *optional*):
  115. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  116. """
  117. BART_ENCODE_INPUTS_DOCSTRING = r"""
  118. Args:
  119. input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
  120. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  121. it.
  122. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  123. [`PreTrainedTokenizer.__call__`] for details.
  124. [What are input IDs?](../glossary#input-ids)
  125. attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
  126. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  127. - 1 for tokens that are **not masked**,
  128. - 0 for tokens that are **masked**.
  129. [What are attention masks?](../glossary#attention-mask)
  130. position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
  131. Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
  132. config.max_position_embeddings - 1]`.
  133. output_attentions (`bool`, *optional*):
  134. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  135. tensors for more detail.
  136. output_hidden_states (`bool`, *optional*):
  137. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  138. more detail.
  139. return_dict (`bool`, *optional*):
  140. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  141. """
  142. BART_DECODE_INPUTS_DOCSTRING = r"""
  143. Args:
  144. decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`):
  145. Indices of decoder input sequence tokens in the vocabulary.
  146. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  147. [`PreTrainedTokenizer.__call__`] for details.
  148. [What are decoder input IDs?](../glossary#decoder-input-ids)
  149. For translation and summarization training, `decoder_input_ids` should be provided. If no
  150. `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
  151. for denoising pre-training following the paper.
  152. encoder_outputs (`tuple(tuple(jnp.ndarray))`):
  153. Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
  154. `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of
  155. hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
  156. encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
  157. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  158. - 1 for tokens that are **not masked**,
  159. - 0 for tokens that are **masked**.
  160. [What are attention masks?](../glossary#attention-mask)
  161. decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
  162. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  163. be used by default.
  164. If you want to change padding behavior, you should modify it to your needs. See diagram 1 in [the
  165. paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
  166. decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
  167. Indices of positions of each decoder input sequence token in the position embeddings. Selected in the
  168. range `[0, config.max_position_embeddings - 1]`.
  169. past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
  170. Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
  171. auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length, num_heads, head_dim]*.
  172. output_attentions (`bool`, *optional*):
  173. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  174. tensors for more detail.
  175. output_hidden_states (`bool`, *optional*):
  176. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  177. more detail.
  178. return_dict (`bool`, *optional*):
  179. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  180. """
  181. def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
  182. """
  183. Shift input ids one token to the right.
  184. """
  185. shifted_input_ids = jnp.zeros_like(input_ids)
  186. shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
  187. shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
  188. shifted_input_ids = jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
  189. return shifted_input_ids
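# A minimal usage sketch for `shift_tokens_right` (illustrative only; the token values and the
# pad/start ids below are assumptions, with -100 marking label positions that should be ignored):
def _example_shift_tokens_right():
    labels = jnp.array([[0, 31414, 2, -100, -100]])
    decoder_input_ids = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=2)
    # Each row is shifted one position to the right, position 0 becomes the decoder start
    # token, and any remaining -100 entries are replaced by the pad token id:
    # [[2, 0, 31414, 2, 1]]
    return decoder_input_ids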
  190. class FlaxBartAttention(nn.Module):
  191. config: BartConfig
  192. embed_dim: int
  193. num_heads: int
  194. dropout: float = 0.0
  195. causal: bool = False
  196. bias: bool = True
  197. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  198. def setup(self) -> None:
  199. self.head_dim = self.embed_dim // self.num_heads
  200. if self.head_dim * self.num_heads != self.embed_dim:
  201. raise ValueError(
  202. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  203. f" and `num_heads`: {self.num_heads})."
  204. )
  205. dense = partial(
  206. nn.Dense,
  207. self.embed_dim,
  208. use_bias=self.bias,
  209. dtype=self.dtype,
  210. kernel_init=jax.nn.initializers.normal(self.config.init_std),
  211. )
  212. self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
  213. self.out_proj = dense()
  214. self.dropout_layer = nn.Dropout(rate=self.dropout)
  215. if self.causal:
  216. self.causal_mask = make_causal_mask(
  217. jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
  218. )
  219. def _split_heads(self, hidden_states):
  220. return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
  221. def _merge_heads(self, hidden_states):
  222. return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
  223. @nn.compact
  224. def _concatenate_to_cache(self, key, value, query, attention_mask):
  225. """
  226. This function takes projected key, value states from a single input token and concatenates the states to cached
  227. states from previous steps. This function is slightly adapted from the official Flax repository:
  228. https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
  229. """
  230. # detect if we're initializing by absence of existing cache data.
  231. is_initialized = self.has_variable("cache", "cached_key")
  232. cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
  233. cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
  234. cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
  235. if is_initialized:
  236. *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
  237. # update key, value caches with our new 1d spatial slices
  238. cur_index = cache_index.value
  239. indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
  240. key = lax.dynamic_update_slice(cached_key.value, key, indices)
  241. value = lax.dynamic_update_slice(cached_value.value, value, indices)
  242. cached_key.value = key
  243. cached_value.value = value
  244. num_updated_cache_vectors = query.shape[1]
  245. cache_index.value = cache_index.value + num_updated_cache_vectors
  246. # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
  247. pad_mask = jnp.broadcast_to(
  248. jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
  249. tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
  250. )
  251. attention_mask = combine_masks(pad_mask, attention_mask)
  252. return key, value, attention_mask
  253. def __call__(
  254. self,
  255. hidden_states: jnp.ndarray,
  256. key_value_states: Optional[jnp.ndarray] = None,
  257. attention_mask: Optional[jnp.ndarray] = None,
  258. init_cache: bool = False,
  259. deterministic: bool = True,
  260. ) -> Tuple[jnp.ndarray]:
  261. """Input shape: Batch x Time x Channel"""
  262. # if key_value_states are provided this layer is used as a cross-attention layer
  263. # for the decoder
  264. is_cross_attention = key_value_states is not None
  265. batch_size = hidden_states.shape[0]
  266. # get query proj
  267. query_states = self.q_proj(hidden_states)
  268. # get key, value proj
  269. if is_cross_attention:
  270. # cross_attentions
  271. key_states = self.k_proj(key_value_states)
  272. value_states = self.v_proj(key_value_states)
  273. else:
  274. # self_attention
  275. key_states = self.k_proj(hidden_states)
  276. value_states = self.v_proj(hidden_states)
  277. query_states = self._split_heads(query_states)
  278. key_states = self._split_heads(key_states)
  279. value_states = self._split_heads(value_states)
  280. # handle cache; prepare causal attention mask
  281. if self.causal:
  282. query_length, key_length = query_states.shape[1], key_states.shape[1]
  283. if self.has_variable("cache", "cached_key"):
  284. mask_shift = self.variables["cache"]["cache_index"]
  285. max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
  286. causal_mask = lax.dynamic_slice(
  287. self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
  288. )
  289. else:
  290. causal_mask = self.causal_mask[:, :, :query_length, :key_length]
  291. causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
  292. # combine masks if needed
  293. if attention_mask is not None and self.causal:
  294. attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
  295. attention_mask = combine_masks(attention_mask, causal_mask)
  296. elif self.causal:
  297. attention_mask = causal_mask
  298. elif attention_mask is not None:
  299. attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
  300. # During fast autoregressive decoding, we feed one position at a time,
  301. # and cache the keys and values step by step.
  302. if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
  303. key_states, value_states, attention_mask = self._concatenate_to_cache(
  304. key_states, value_states, query_states, attention_mask
  305. )
  306. # Convert the boolean attention mask to an attention bias.
  307. if attention_mask is not None:
  308. # attention mask in the form of attention bias
  309. attention_bias = lax.select(
  310. attention_mask > 0,
  311. jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
  312. jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
  313. )
  314. else:
  315. attention_bias = None
  316. dropout_rng = None
  317. if not deterministic and self.dropout > 0.0:
  318. dropout_rng = self.make_rng("dropout")
  319. attn_weights = dot_product_attention_weights(
  320. query_states,
  321. key_states,
  322. bias=attention_bias,
  323. dropout_rng=dropout_rng,
  324. dropout_rate=self.dropout,
  325. broadcast_dropout=True,
  326. deterministic=deterministic,
  327. dtype=self.dtype,
  328. precision=None,
  329. )
  330. attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
  331. attn_output = self._merge_heads(attn_output)
  332. attn_output = self.out_proj(attn_output)
  333. return attn_output, attn_weights
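# A minimal sketch of the cache write performed by `_concatenate_to_cache` above (illustrative;
# toy shapes only): the current step's key/value slice is written into a pre-allocated buffer of
# length `max_length` at offset `cache_index` via `lax.dynamic_update_slice`.
def _example_cache_update(max_length: int = 8, num_heads: int = 2, head_dim: int = 4):
    cached_key = jnp.zeros((1, max_length, num_heads, head_dim))
    new_key = jnp.ones((1, 1, num_heads, head_dim))  # projected key for a single decoding step
    cache_index = 3  # number of positions already written
    updated = lax.dynamic_update_slice(cached_key, new_key, (0, cache_index, 0, 0))
    return updated[0, :, 0, 0]  # -> [0, 0, 0, 1, 0, 0, 0, 0]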
  334. class FlaxBartEncoderLayer(nn.Module):
  335. config: BartConfig
  336. dtype: jnp.dtype = jnp.float32
  337. def setup(self) -> None:
  338. self.embed_dim = self.config.d_model
  339. self.self_attn = FlaxBartAttention(
  340. config=self.config,
  341. embed_dim=self.embed_dim,
  342. num_heads=self.config.encoder_attention_heads,
  343. dropout=self.config.attention_dropout,
  344. dtype=self.dtype,
  345. )
  346. self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
  347. self.dropout_layer = nn.Dropout(rate=self.config.dropout)
  348. self.activation_fn = ACT2FN[self.config.activation_function]
  349. self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
  350. self.fc1 = nn.Dense(
  351. self.config.encoder_ffn_dim,
  352. dtype=self.dtype,
  353. kernel_init=jax.nn.initializers.normal(self.config.init_std),
  354. )
  355. self.fc2 = nn.Dense(
  356. self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
  357. )
  358. self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
  359. def __call__(
  360. self,
  361. hidden_states: jnp.ndarray,
  362. attention_mask: jnp.ndarray,
  363. output_attentions: bool = True,
  364. deterministic: bool = True,
  365. ) -> Tuple[jnp.ndarray]:
  366. residual = hidden_states
  367. hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask)
  368. hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
  369. hidden_states = residual + hidden_states
  370. hidden_states = self.self_attn_layer_norm(hidden_states)
  371. residual = hidden_states
  372. hidden_states = self.activation_fn(self.fc1(hidden_states))
  373. hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
  374. hidden_states = self.fc2(hidden_states)
  375. hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
  376. hidden_states = residual + hidden_states
  377. hidden_states = self.final_layer_norm(hidden_states)
  378. outputs = (hidden_states,)
  379. if output_attentions:
  380. outputs += (attn_weights,)
  381. return outputs
  382. class FlaxBartEncoderLayerCollection(nn.Module):
  383. config: BartConfig
  384. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  385. def setup(self):
  386. self.layers = [
  387. FlaxBartEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers)
  388. ]
  389. self.layerdrop = self.config.encoder_layerdrop
  390. def __call__(
  391. self,
  392. hidden_states,
  393. attention_mask,
  394. deterministic: bool = True,
  395. output_attentions: bool = False,
  396. output_hidden_states: bool = False,
  397. return_dict: bool = True,
  398. ):
  399. all_attentions = () if output_attentions else None
  400. all_hidden_states = () if output_hidden_states else None
  401. for encoder_layer in self.layers:
  402. if output_hidden_states:
  403. all_hidden_states = all_hidden_states + (hidden_states,)
  404. # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
  405. dropout_probability = random.uniform(0, 1)
  406. if not deterministic and (dropout_probability < self.layerdrop): # skip the layer
  407. layer_outputs = (None, None)
  408. else:
  409. layer_outputs = encoder_layer(
  410. hidden_states,
  411. attention_mask,
  412. output_attentions,
  413. deterministic,
  414. )
  415. hidden_states = layer_outputs[0]
  416. if output_attentions:
  417. all_attentions = all_attentions + (layer_outputs[1],)
  418. if output_hidden_states:
  419. all_hidden_states += (hidden_states,)
  420. outputs = (hidden_states, all_hidden_states, all_attentions)
  421. if not return_dict:
  422. return tuple(v for v in outputs if v is not None)
  423. return FlaxBaseModelOutput(
  424. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
  425. )
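# A minimal sketch of the LayerDrop rule used by the layer collections (illustrative only):
# during training (`deterministic=False`) each layer is skipped independently with probability
# `layerdrop`; in deterministic mode every layer always runs.
def _example_layerdrop_decision(layerdrop: float = 0.1, deterministic: bool = False) -> bool:
    dropout_probability = random.uniform(0, 1)  # Python-level randomness, as in the collections above
    return (not deterministic) and (dropout_probability < layerdrop)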
  426. class FlaxBartDecoderLayer(nn.Module):
  427. config: BartConfig
  428. dtype: jnp.dtype = jnp.float32
  429. def setup(self) -> None:
  430. self.embed_dim = self.config.d_model
  431. self.self_attn = FlaxBartAttention(
  432. config=self.config,
  433. embed_dim=self.embed_dim,
  434. num_heads=self.config.decoder_attention_heads,
  435. dropout=self.config.attention_dropout,
  436. causal=True,
  437. dtype=self.dtype,
  438. )
  439. self.dropout_layer = nn.Dropout(rate=self.config.dropout)
  440. self.activation_fn = ACT2FN[self.config.activation_function]
  441. self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
  442. self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
  443. self.encoder_attn = FlaxBartAttention(
  444. config=self.config,
  445. embed_dim=self.embed_dim,
  446. num_heads=self.config.decoder_attention_heads,
  447. dropout=self.config.attention_dropout,
  448. dtype=self.dtype,
  449. )
  450. self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
  451. self.fc1 = nn.Dense(
  452. self.config.decoder_ffn_dim,
  453. dtype=self.dtype,
  454. kernel_init=jax.nn.initializers.normal(self.config.init_std),
  455. )
  456. self.fc2 = nn.Dense(
  457. self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
  458. )
  459. self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
  460. def __call__(
  461. self,
  462. hidden_states: jnp.ndarray,
  463. attention_mask: jnp.ndarray,
  464. encoder_hidden_states: Optional[jnp.ndarray] = None,
  465. encoder_attention_mask: Optional[jnp.ndarray] = None,
  466. init_cache: bool = False,
  467. output_attentions: bool = True,
  468. deterministic: bool = True,
  469. ) -> Tuple[jnp.ndarray]:
  470. residual = hidden_states
  471. # Self Attention
  472. hidden_states, self_attn_weights = self.self_attn(
  473. hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache
  474. )
  475. hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
  476. hidden_states = residual + hidden_states
  477. hidden_states = self.self_attn_layer_norm(hidden_states)
  478. # Cross-Attention Block
  479. cross_attn_weights = None
  480. if encoder_hidden_states is not None:
  481. residual = hidden_states
  482. hidden_states, cross_attn_weights = self.encoder_attn(
  483. hidden_states=hidden_states,
  484. key_value_states=encoder_hidden_states,
  485. attention_mask=encoder_attention_mask,
  486. )
  487. hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
  488. hidden_states = residual + hidden_states
  489. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  490. # Fully Connected
  491. residual = hidden_states
  492. hidden_states = self.activation_fn(self.fc1(hidden_states))
  493. hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic)
  494. hidden_states = self.fc2(hidden_states)
  495. hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
  496. hidden_states = residual + hidden_states
  497. hidden_states = self.final_layer_norm(hidden_states)
  498. outputs = (hidden_states,)
  499. if output_attentions:
  500. outputs += (self_attn_weights, cross_attn_weights)
  501. return outputs
  502. class FlaxBartDecoderLayerCollection(nn.Module):
  503. config: BartConfig
  504. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  505. def setup(self):
  506. self.layers = [
  507. FlaxBartDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers)
  508. ]
  509. self.layerdrop = self.config.decoder_layerdrop
  510. def __call__(
  511. self,
  512. hidden_states,
  513. attention_mask,
  514. encoder_hidden_states: Optional[jnp.ndarray] = None,
  515. encoder_attention_mask: Optional[jnp.ndarray] = None,
  516. deterministic: bool = True,
  517. init_cache: bool = False,
  518. output_attentions: bool = False,
  519. output_hidden_states: bool = False,
  520. return_dict: bool = True,
  521. ):
  522. # decoder layers
  523. all_hidden_states = () if output_hidden_states else None
  524. all_self_attns = () if output_attentions else None
  525. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  526. for decoder_layer in self.layers:
  527. if output_hidden_states:
  528. all_hidden_states += (hidden_states,)
  529. # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
  530. dropout_probability = random.uniform(0, 1)
  531. if not deterministic and (dropout_probability < self.layerdrop):
  532. layer_outputs = (None, None, None)
  533. else:
  534. layer_outputs = decoder_layer(
  535. hidden_states,
  536. attention_mask=attention_mask,
  537. encoder_hidden_states=encoder_hidden_states,
  538. encoder_attention_mask=encoder_attention_mask,
  539. init_cache=init_cache,
  540. output_attentions=output_attentions,
  541. deterministic=deterministic,
  542. )
  543. hidden_states = layer_outputs[0]
  544. if output_attentions:
  545. all_self_attns += (layer_outputs[1],)
  546. if encoder_hidden_states is not None:
  547. all_cross_attentions += (layer_outputs[2],)
  548. # add hidden states from the last decoder layer
  549. if output_hidden_states:
  550. all_hidden_states += (hidden_states,)
  551. outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions]
  552. if not return_dict:
  553. return tuple(v for v in outputs if v is not None)
  554. return FlaxBaseModelOutputWithPastAndCrossAttentions(
  555. last_hidden_state=hidden_states,
  556. hidden_states=all_hidden_states,
  557. attentions=all_self_attns,
  558. cross_attentions=all_cross_attentions,
  559. )
  560. class FlaxBartClassificationHead(nn.Module):
  561. """Head for sentence-level classification tasks."""
  562. config: BartConfig
  563. inner_dim: int
  564. num_classes: int
  565. pooler_dropout: float
  566. dtype: jnp.dtype = jnp.float32
  567. def setup(self):
  568. self.dense = nn.Dense(
  569. self.inner_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
  570. )
  571. self.dropout = nn.Dropout(rate=self.pooler_dropout)
  572. self.out_proj = nn.Dense(
  573. self.num_classes,
  574. dtype=self.dtype,
  575. kernel_init=jax.nn.initializers.normal(self.config.init_std),
  576. )
  577. def __call__(self, hidden_states: jnp.ndarray, deterministic: bool):
  578. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  579. hidden_states = self.dense(hidden_states)
  580. hidden_states = jnp.tanh(hidden_states)
  581. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  582. hidden_states = self.out_proj(hidden_states)
  583. return hidden_states
  584. class FlaxBartEncoder(nn.Module):
  585. config: BartConfig
  586. embed_tokens: nn.Embed
  587. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  588. def setup(self):
  589. self.dropout_layer = nn.Dropout(rate=self.config.dropout)
  590. embed_dim = self.config.d_model
  591. self.padding_idx = self.config.pad_token_id
  592. self.max_source_positions = self.config.max_position_embeddings
  593. self.embed_scale = math.sqrt(embed_dim) if self.config.scale_embedding else 1.0
  594. # Bart is set up so that if padding_idx is specified, the embedding ids are offset by 2
  595. # and num_embeddings is adjusted appropriately. Other models don't have this hack.
  596. self.offset = 2
  597. self.embed_positions = nn.Embed(
  598. self.config.max_position_embeddings + self.offset,
  599. embed_dim,
  600. embedding_init=jax.nn.initializers.normal(self.config.init_std),
  601. dtype=self.dtype,
  602. )
  603. self.layers = FlaxBartEncoderLayerCollection(self.config, self.dtype)
  604. self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
  605. def __call__(
  606. self,
  607. input_ids,
  608. attention_mask,
  609. position_ids,
  610. output_attentions: bool = False,
  611. output_hidden_states: bool = False,
  612. return_dict: bool = True,
  613. deterministic: bool = True,
  614. ):
  615. input_shape = input_ids.shape
  616. input_ids = input_ids.reshape(-1, input_shape[-1])
  617. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  618. embed_pos = self.embed_positions(position_ids + self.offset)
  619. hidden_states = inputs_embeds + embed_pos
  620. hidden_states = self.layernorm_embedding(hidden_states)
  621. hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
  622. outputs = self.layers(
  623. hidden_states,
  624. attention_mask,
  625. deterministic=deterministic,
  626. output_attentions=output_attentions,
  627. output_hidden_states=output_hidden_states,
  628. return_dict=return_dict,
  629. )
  630. if not return_dict:
  631. return outputs
  632. return FlaxBaseModelOutput(
  633. last_hidden_state=outputs.last_hidden_state,
  634. hidden_states=outputs.hidden_states,
  635. attentions=outputs.attentions,
  636. )
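# A minimal sketch of the position-embedding offset used by the encoder above and the decoder
# below (illustrative; toy sizes): BART reserves the first two rows of the position table, so
# position id `p` looks up row `p + offset` in a table of size `max_position_embeddings + offset`.
def _example_position_offset(max_position_embeddings: int = 32, d_model: int = 16):
    offset = 2
    embed_positions = nn.Embed(num_embeddings=max_position_embeddings + offset, features=d_model)
    position_ids = jnp.arange(8)[None, :]  # positions 0..7 for a single sequence
    variables = embed_positions.init(jax.random.PRNGKey(0), position_ids + offset)
    return embed_positions.apply(variables, position_ids + offset).shape  # (1, 8, 16)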
  637. class FlaxBartDecoder(nn.Module):
  638. config: BartConfig
  639. embed_tokens: nn.Embed
  640. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  641. def setup(self):
  642. self.dropout_layer = nn.Dropout(rate=self.config.dropout)
  643. embed_dim = self.config.d_model
  644. self.padding_idx = self.config.pad_token_id
  645. self.max_target_positions = self.config.max_position_embeddings
  646. self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0
  647. # Bart is set up so that if padding_idx is specified, the embedding ids are offset by 2
  648. # and num_embeddings is adjusted appropriately. Other models don't have this hack.
  649. self.offset = 2
  650. self.embed_positions = nn.Embed(
  651. self.config.max_position_embeddings + self.offset,
  652. embed_dim,
  653. embedding_init=jax.nn.initializers.normal(self.config.init_std),
  654. dtype=self.dtype,
  655. )
  656. self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
  657. self.layernorm_embedding = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
  658. def __call__(
  659. self,
  660. input_ids,
  661. attention_mask,
  662. position_ids,
  663. encoder_hidden_states: Optional[jnp.ndarray] = None,
  664. encoder_attention_mask: Optional[jnp.ndarray] = None,
  665. init_cache: bool = False,
  666. output_attentions: bool = False,
  667. output_hidden_states: bool = False,
  668. return_dict: bool = True,
  669. deterministic: bool = True,
  670. ):
  671. input_shape = input_ids.shape
  672. input_ids = input_ids.reshape(-1, input_shape[-1])
  673. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  674. # embed positions
  675. positions = self.embed_positions(position_ids + self.offset)
  676. hidden_states = inputs_embeds + positions
  677. hidden_states = self.layernorm_embedding(hidden_states)
  678. hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
  679. outputs = self.layers(
  680. hidden_states,
  681. attention_mask,
  682. encoder_hidden_states,
  683. encoder_attention_mask,
  684. deterministic=deterministic,
  685. init_cache=init_cache,
  686. output_attentions=output_attentions,
  687. output_hidden_states=output_hidden_states,
  688. return_dict=return_dict,
  689. )
  690. if not return_dict:
  691. return outputs
  692. return FlaxBaseModelOutputWithPastAndCrossAttentions(
  693. last_hidden_state=outputs.last_hidden_state,
  694. hidden_states=outputs.hidden_states,
  695. attentions=outputs.attentions,
  696. cross_attentions=outputs.cross_attentions,
  697. )
  698. class FlaxBartModule(nn.Module):
  699. config: BartConfig
  700. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  701. def setup(self):
  702. self.shared = nn.Embed(
  703. self.config.vocab_size,
  704. self.config.d_model,
  705. embedding_init=jax.nn.initializers.normal(self.config.init_std),
  706. dtype=self.dtype,
  707. )
  708. self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
  709. self.decoder = FlaxBartDecoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
  710. def _get_encoder_module(self):
  711. return self.encoder
  712. def _get_decoder_module(self):
  713. return self.decoder
  714. def __call__(
  715. self,
  716. input_ids,
  717. attention_mask,
  718. decoder_input_ids,
  719. decoder_attention_mask,
  720. position_ids,
  721. decoder_position_ids,
  722. output_attentions: bool = False,
  723. output_hidden_states: bool = False,
  724. return_dict: bool = True,
  725. deterministic: bool = True,
  726. ):
  727. encoder_outputs = self.encoder(
  728. input_ids=input_ids,
  729. attention_mask=attention_mask,
  730. position_ids=position_ids,
  731. output_attentions=output_attentions,
  732. output_hidden_states=output_hidden_states,
  733. return_dict=return_dict,
  734. deterministic=deterministic,
  735. )
  736. decoder_outputs = self.decoder(
  737. input_ids=decoder_input_ids,
  738. attention_mask=decoder_attention_mask,
  739. position_ids=decoder_position_ids,
  740. encoder_hidden_states=encoder_outputs[0],
  741. encoder_attention_mask=attention_mask,
  742. output_attentions=output_attentions,
  743. output_hidden_states=output_hidden_states,
  744. return_dict=return_dict,
  745. deterministic=deterministic,
  746. )
  747. if not return_dict:
  748. return decoder_outputs + encoder_outputs
  749. return FlaxSeq2SeqModelOutput(
  750. last_hidden_state=decoder_outputs.last_hidden_state,
  751. decoder_hidden_states=decoder_outputs.hidden_states,
  752. decoder_attentions=decoder_outputs.attentions,
  753. cross_attentions=decoder_outputs.cross_attentions,
  754. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  755. encoder_hidden_states=encoder_outputs.hidden_states,
  756. encoder_attentions=encoder_outputs.attentions,
  757. )
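# A minimal sketch of initializing the bare `FlaxBartModule` directly (illustrative; the tiny
# config values below are assumptions chosen only to keep the example fast):
def _example_init_bart_module():
    config = BartConfig(
        vocab_size=64, d_model=16, encoder_layers=1, decoder_layers=1,
        encoder_attention_heads=2, decoder_attention_heads=2,
        encoder_ffn_dim=32, decoder_ffn_dim=32, max_position_embeddings=32,
    )
    module = FlaxBartModule(config=config)
    input_ids = jnp.ones((1, 4), dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(4)[None, :], (1, 4))
    variables = module.init(
        jax.random.PRNGKey(0),
        input_ids, attention_mask, input_ids, attention_mask, position_ids, position_ids,
    )
    return jax.tree_util.tree_map(jnp.shape, variables["params"])  # parameter shapes only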
  758. class FlaxBartPreTrainedModel(FlaxPreTrainedModel):
  759. config_class = BartConfig
  760. base_model_prefix: str = "model"
  761. module_class: nn.Module = None
  762. def __init__(
  763. self,
  764. config: BartConfig,
  765. input_shape: Tuple[int] = (1, 1),
  766. seed: int = 0,
  767. dtype: jnp.dtype = jnp.float32,
  768. _do_init: bool = True,
  769. **kwargs,
  770. ):
  771. module = self.module_class(config=config, dtype=dtype, **kwargs)
  772. super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
  773. def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
  774. # init input tensors
  775. input_ids = jnp.zeros(input_shape, dtype="i4")
  776. # make sure initialization pass will work for FlaxBartForSequenceClassificationModule
  777. input_ids = input_ids.at[(..., -1)].set(self.config.eos_token_id)
  778. attention_mask = jnp.ones_like(input_ids)
  779. decoder_input_ids = input_ids
  780. decoder_attention_mask = jnp.ones_like(input_ids)
  781. batch_size, sequence_length = input_ids.shape
  782. position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
  783. decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
  784. params_rng, dropout_rng = jax.random.split(rng)
  785. rngs = {"params": params_rng, "dropout": dropout_rng}
  786. random_params = self.module.init(
  787. rngs,
  788. input_ids,
  789. attention_mask,
  790. decoder_input_ids,
  791. decoder_attention_mask,
  792. position_ids,
  793. decoder_position_ids,
  794. )["params"]
  795. if params is not None:
  796. random_params = flatten_dict(unfreeze(random_params))
  797. params = flatten_dict(unfreeze(params))
  798. for missing_key in self._missing_keys:
  799. params[missing_key] = random_params[missing_key]
  800. self._missing_keys = set()
  801. return freeze(unflatten_dict(params))
  802. else:
  803. return random_params
  804. def init_cache(self, batch_size, max_length, encoder_outputs):
  805. r"""
  806. Args:
  807. batch_size (`int`):
  808. batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
  809. max_length (`int`):
  810. maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
  811. cache.
  812. encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray))]`):
  813. `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
  814. `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`
  815. is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
  816. cross-attention of the decoder.
  817. """
  818. # init input variables to retrieve cache
  819. decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
  820. decoder_attention_mask = jnp.ones_like(decoder_input_ids)
  821. decoder_position_ids = jnp.broadcast_to(
  822. jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape
  823. )
  824. def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
  825. decoder_module = module._get_decoder_module()
  826. return decoder_module(
  827. decoder_input_ids,
  828. decoder_attention_mask,
  829. decoder_position_ids,
  830. **kwargs,
  831. )
  832. init_variables = self.module.init(
  833. jax.random.PRNGKey(0),
  834. decoder_input_ids=decoder_input_ids,
  835. decoder_attention_mask=decoder_attention_mask,
  836. decoder_position_ids=decoder_position_ids,
  837. encoder_hidden_states=encoder_outputs[0],
  838. init_cache=True,
  839. method=_decoder_forward, # we only need to call the decoder to init the cache
  840. )
  841. return unfreeze(init_variables["cache"])
  842. @add_start_docstrings(BART_ENCODE_INPUTS_DOCSTRING)
  843. @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=BartConfig)
  844. def encode(
  845. self,
  846. input_ids: jnp.ndarray,
  847. attention_mask: Optional[jnp.ndarray] = None,
  848. position_ids: Optional[jnp.ndarray] = None,
  849. output_attentions: Optional[bool] = None,
  850. output_hidden_states: Optional[bool] = None,
  851. return_dict: Optional[bool] = None,
  852. train: bool = False,
  853. params: dict = None,
  854. dropout_rng: PRNGKey = None,
  855. ):
  856. r"""
  857. Returns:
  858. Example:
  859. ```python
  860. >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration
  861. >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
  862. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
  863. >>> text = "My friends are cool but they eat too many carbs."
  864. >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax")
  865. >>> encoder_outputs = model.encode(**inputs)
  866. ```"""
  867. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  868. output_hidden_states = (
  869. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  870. )
  871. return_dict = return_dict if return_dict is not None else self.config.return_dict
  872. if attention_mask is None:
  873. attention_mask = jnp.ones_like(input_ids)
  874. if position_ids is None:
  875. batch_size, sequence_length = input_ids.shape
  876. position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
  877. # Handle any PRNG if needed
  878. rngs = {}
  879. if dropout_rng is not None:
  880. rngs["dropout"] = dropout_rng
  881. def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
  882. encode_module = module._get_encoder_module()
  883. return encode_module(input_ids, attention_mask, position_ids, **kwargs)
  884. return self.module.apply(
  885. {"params": params or self.params},
  886. input_ids=jnp.array(input_ids, dtype="i4"),
  887. attention_mask=jnp.array(attention_mask, dtype="i4"),
  888. position_ids=jnp.array(position_ids, dtype="i4"),
  889. output_attentions=output_attentions,
  890. output_hidden_states=output_hidden_states,
  891. return_dict=return_dict,
  892. deterministic=not train,
  893. rngs=rngs,
  894. method=_encoder_forward,
  895. )
  896. @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING)
  897. @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=BartConfig)
  898. def decode(
  899. self,
  900. decoder_input_ids,
  901. encoder_outputs,
  902. encoder_attention_mask: Optional[jnp.ndarray] = None,
  903. decoder_attention_mask: Optional[jnp.ndarray] = None,
  904. decoder_position_ids: Optional[jnp.ndarray] = None,
  905. past_key_values: dict = None,
  906. output_attentions: Optional[bool] = None,
  907. output_hidden_states: Optional[bool] = None,
  908. return_dict: Optional[bool] = None,
  909. train: bool = False,
  910. params: dict = None,
  911. dropout_rng: PRNGKey = None,
  912. ):
  913. r"""
  914. Returns:
  915. Example:
  916. ```python
  917. >>> import jax.numpy as jnp
  918. >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration
  919. >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
  920. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
  921. >>> text = "My friends are cool but they eat too many carbs."
  922. >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax")
  923. >>> encoder_outputs = model.encode(**inputs)
  924. >>> decoder_start_token_id = model.config.decoder_start_token_id
  925. >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
  926. >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
  927. >>> last_decoder_hidden_states = outputs.last_hidden_state
  928. ```"""
  929. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  930. output_hidden_states = (
  931. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  932. )
  933. return_dict = return_dict if return_dict is not None else self.config.return_dict
  934. encoder_hidden_states = encoder_outputs[0]
  935. if encoder_attention_mask is None:
  936. batch_size, sequence_length = encoder_hidden_states.shape[:2]
  937. encoder_attention_mask = jnp.ones((batch_size, sequence_length))
  938. batch_size, sequence_length = decoder_input_ids.shape
  939. if decoder_attention_mask is None:
  940. decoder_attention_mask = jnp.ones((batch_size, sequence_length))
  941. if decoder_position_ids is None:
  942. if past_key_values is not None:
  943. raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
  944. decoder_position_ids = jnp.broadcast_to(
  945. jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
  946. )
  947. # Handle any PRNG if needed
  948. rngs = {}
  949. if dropout_rng is not None:
  950. rngs["dropout"] = dropout_rng
  951. inputs = {"params": params or self.params}
  952. # If past_key_values are passed, the cache is already initialized and the private flag init_cache has to
  953. # be passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can
  954. # be changed by the FlaxBartAttention module.
  955. if past_key_values:
  956. inputs["cache"] = past_key_values
  957. mutable = ["cache"]
  958. else:
  959. mutable = False
  960. def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
  961. decoder_module = module._get_decoder_module()
  962. return decoder_module(
  963. decoder_input_ids,
  964. decoder_attention_mask,
  965. decoder_position_ids,
  966. **kwargs,
  967. )
  968. outputs = self.module.apply(
  969. inputs,
  970. decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
  971. decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
  972. decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
  973. encoder_hidden_states=encoder_hidden_states,
  974. encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
  975. output_attentions=output_attentions,
  976. output_hidden_states=output_hidden_states,
  977. return_dict=return_dict,
  978. deterministic=not train,
  979. rngs=rngs,
  980. mutable=mutable,
  981. method=_decoder_forward,
  982. )
  983. # add updated cache to model output
  984. if past_key_values is not None and return_dict:
  985. outputs, past = outputs
  986. outputs["past_key_values"] = unfreeze(past["cache"])
  987. return outputs
  988. elif past_key_values is not None and not return_dict:
  989. outputs, past = outputs
  990. outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
  991. return outputs
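# Illustrative sketch of how `init_cache` and `decode` work together for fast autoregressive
# decoding (comments only; `model`, `inputs` and `max_length` are assumed to exist as in the
# docstring example above):
#
#     encoder_outputs = model.encode(**inputs)
#     past_key_values = model.init_cache(inputs.input_ids.shape[0], max_length, encoder_outputs)
#     decoder_input_ids = jnp.full((inputs.input_ids.shape[0], 1), model.config.decoder_start_token_id, dtype="i4")
#     outputs = model.decode(
#         decoder_input_ids,
#         encoder_outputs,
#         decoder_position_ids=jnp.zeros_like(decoder_input_ids),
#         past_key_values=past_key_values,
#     )
#     past_key_values = outputs.past_key_values  # updated cache, passed back in at the next step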
    @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # prepare decoder inputs
        if decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                input_ids, self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id
            )
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )

@add_start_docstrings(
    "The bare Bart Model transformer outputting raw hidden-states without any specific head on top.",
    BART_START_DOCSTRING,
)
class FlaxBartModel(FlaxBartPreTrainedModel):
    config: BartConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    module_class = FlaxBartModule


append_call_sample_docstring(FlaxBartModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
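
# A minimal usage sketch for the bare model (kept as a comment so it is not executed at import time). The checkpoint
# name "facebook/bart-base" is only an assumed example; any BART checkpoint with Flax weights works the same way.
#
#     from transformers import AutoTokenizer, FlaxBartModel
#
#     tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
#     model = FlaxBartModel.from_pretrained("facebook/bart-base")
#     inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="jax")
#     outputs = model(**inputs)
#     last_hidden_state = outputs.last_hidden_state  # (batch_size, sequence_length, d_model)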

class FlaxBartForConditionalGenerationModule(nn.Module):
    config: BartConfig
    dtype: jnp.dtype = jnp.float32
    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.model.shared.num_embeddings,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )
        self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.model.shared.num_embeddings))

    def _get_encoder_module(self):
        return self.model.encoder

    def _get_decoder_module(self):
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            shared_embedding = self.model.variables["params"]["shared"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        lm_logits += jax.lax.stop_gradient(self.final_logits_bias.astype(self.dtype))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return output

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

@add_start_docstrings(
    "The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
class FlaxBartForConditionalGeneration(FlaxBartPreTrainedModel):
    module_class = FlaxBartForConditionalGenerationModule
    dtype: jnp.dtype = jnp.float32

    @add_start_docstrings(BART_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=BartConfig)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> import jax.numpy as jnp
        >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

        >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> inputs = tokenizer(text, max_length=1024, return_tensors="jax")
        >>> encoder_outputs = model.encode(**inputs)

        >>> decoder_start_token_id = model.config.decoder_start_token_id
        >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized and the private flag init_cache has to be
        # passed down to ensure the cache is used. The cache must also be marked as mutable so that it can be
        # updated by the FlaxBartAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs):
            decoder_module = module._get_decoder_module()
            outputs = decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                **kwargs,
            )
            hidden_states = outputs[0]

            if self.config.tie_word_embeddings:
                shared_embedding = module.model.variables["params"]["shared"]["embedding"]
                lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
            else:
                lm_logits = module.lm_head(hidden_states)

            lm_logits += module.final_logits_bias.astype(self.dtype)
            return lm_logits, outputs

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        if past_key_values is None:
            lm_logits, decoder_outputs = outputs
        else:
            (lm_logits, decoder_outputs), past = outputs

        if return_dict:
            outputs = FlaxCausalLMOutputWithCrossAttentions(
                logits=lm_logits,
                hidden_states=decoder_outputs.hidden_states,
                attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
            )
        else:
            outputs = (lm_logits,) + decoder_outputs[1:]

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        max_length,
        attention_mask: Optional[jax.Array] = None,
        decoder_attention_mask: Optional[jax.Array] = None,
        encoder_outputs=None,
        **kwargs,
    ):
        # initializing the cache
        batch_size, seq_length = decoder_input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if decoder_attention_mask is not None:
            position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "encoder_attention_mask": attention_mask,
            "decoder_attention_mask": extended_attention_mask,
            "decoder_position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
        return model_kwargs
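
# A minimal greedy-decoding sketch showing the cache flow that `generate()` wraps: `init_cache` builds the cache,
# `decode` consumes it and returns the updated `past_key_values` each step. Kept as a comment so it is not executed;
# "facebook/bart-large-cnn" and `max_length = 20` are only assumed example values.
#
#     import jax.numpy as jnp
#     from transformers import AutoTokenizer, FlaxBartForConditionalGeneration
#
#     model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
#     tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
#
#     inputs = tokenizer("My friends are cool but they eat too many carbs.", return_tensors="jax")
#     encoder_outputs = model.encode(**inputs)
#
#     max_length = 20
#     batch_size = inputs["input_ids"].shape[0]
#     past_key_values = model.init_cache(batch_size, max_length, encoder_outputs)
#
#     # static attention mask over the full cache length, as in prepare_inputs_for_generation
#     decoder_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
#     decoder_input_ids = jnp.full((batch_size, 1), model.config.decoder_start_token_id, dtype="i4")
#     generated = decoder_input_ids
#     for step in range(max_length - 1):
#         outputs = model.decode(
#             decoder_input_ids,
#             encoder_outputs,
#             decoder_attention_mask=decoder_attention_mask,
#             decoder_position_ids=jnp.full((batch_size, 1), step, dtype="i4"),
#             past_key_values=past_key_values,
#         )
#         past_key_values = outputs.past_key_values
#         decoder_input_ids = jnp.argmax(outputs.logits[:, -1:], axis=-1).astype("i4")
#         generated = jnp.concatenate([generated, decoder_input_ids], axis=-1)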

FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING = """
    Returns:

    Summarization example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

    >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")

    >>> ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
    >>> inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors="np")

    >>> # Generate Summary
    >>> summary_ids = model.generate(inputs["input_ids"]).sequences
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False))
    ```

    Mask filling example:

    ```python
    >>> import jax
    >>> from transformers import AutoTokenizer, FlaxBartForConditionalGeneration

    >>> model = FlaxBartForConditionalGeneration.from_pretrained("facebook/bart-large")
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large")

    >>> TXT = "My friends are <mask> but they eat too many carbs."
    >>> input_ids = tokenizer([TXT], return_tensors="jax")["input_ids"]

    >>> logits = model(input_ids).logits
    >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[0].item()
    >>> probs = jax.nn.softmax(logits[0, masked_index], axis=0)
    >>> values, predictions = jax.lax.top_k(probs, k=1)

    >>> tokenizer.decode(predictions).split()
    ```
"""

overwrite_call_docstring(
    FlaxBartForConditionalGeneration, BART_INPUTS_DOCSTRING + FLAX_BART_CONDITIONAL_GENERATION_DOCSTRING
)
append_replace_return_docstrings(
    FlaxBartForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
)

class FlaxBartForSequenceClassificationModule(nn.Module):
    config: BartConfig
    dtype: jnp.dtype = jnp.float32
    num_labels: Optional[int] = None

    def setup(self):
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.classification_head = FlaxBartClassificationHead(
            config=self.config,
            inner_dim=self.config.d_model,
            num_classes=self.num_labels if self.num_labels is not None else self.config.num_labels,
            pooler_dropout=self.config.classifier_dropout,
        )

    def _get_encoder_module(self):
        return self.model.encoder

    def _get_decoder_module(self):
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        hidden_states = outputs[0]  # last hidden state

        eos_mask = jnp.where(input_ids == self.config.eos_token_id, 1, 0)

        # The first condition is necessary to overcome jax._src.errors.ConcretizationTypeError during JIT compilation
        if not isinstance(eos_mask, jax.interpreters.partial_eval.DynamicJaxprTracer):
            if len(jnp.unique(eos_mask.sum(1))) > 1:
                raise ValueError("All examples must have the same number of <eos> tokens.")

            if any(eos_mask.sum(1) == 0):
                raise ValueError("There are missing <eos> tokens in input_ids")

            # Ensure to keep 1 only for the last <eos> token for each example
            eos_mask_noised = eos_mask + jnp.arange(eos_mask.shape[1]) * 1e-6
            eos_mask = jnp.where(eos_mask_noised == eos_mask_noised.max(1).reshape(-1, 1), 1, 0)

        sentence_representation = jnp.einsum("ijk, ij -> ijk", hidden_states, eos_mask).sum(1)
        logits = self.classification_head(sentence_representation, deterministic=deterministic)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return output

        return FlaxSeq2SeqSequenceClassifierOutput(
            logits=logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

@add_start_docstrings(
    """
    Bart model with a sequence classification head on top (a linear layer on top of the pooled output) e.g. for GLUE
    tasks.
    """,
    BART_START_DOCSTRING,
)
class FlaxBartForSequenceClassification(FlaxBartPreTrainedModel):
    module_class = FlaxBartForSequenceClassificationModule
    dtype = jnp.float32


append_call_sample_docstring(
    FlaxBartForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSeq2SeqSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)
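
# A minimal usage sketch for the classification head; the pooled representation is taken from the final <eos> token,
# as implemented above. "facebook/bart-base" is only an assumed checkpoint, and its classification head would be
# randomly initialized unless the checkpoint was fine-tuned for the task.
#
#     import jax.numpy as jnp
#     from transformers import AutoTokenizer, FlaxBartForSequenceClassification
#
#     tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
#     model = FlaxBartForSequenceClassification.from_pretrained("facebook/bart-base")
#     inputs = tokenizer("This movie was great!", return_tensors="jax")
#     logits = model(**inputs).logits  # (batch_size, num_labels)
#     predicted_class = int(jnp.argmax(logits, axis=-1)[0])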

class FlaxBartForQuestionAnsweringModule(nn.Module):
    config: BartConfig
    dtype: jnp.dtype = jnp.float32
    num_labels = 2

    def setup(self):
        self.model = FlaxBartModule(config=self.config, dtype=self.dtype)
        self.qa_outputs = nn.Dense(
            self.num_labels, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std)
        )

    def _get_encoder_module(self):
        return self.model.encoder

    def _get_decoder_module(self):
        return self.model.decoder

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = jnp.split(logits, logits.shape[-1], axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            output = (start_logits, end_logits) + outputs[1:]
            return output

        return FlaxSeq2SeqQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )


@add_start_docstrings(
    """
    BART Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    BART_START_DOCSTRING,
)
class FlaxBartForQuestionAnswering(FlaxBartPreTrainedModel):
    module_class = FlaxBartForQuestionAnsweringModule
    dtype = jnp.float32


append_call_sample_docstring(
    FlaxBartForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxSeq2SeqQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)
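
# A minimal usage sketch for extractive QA: the answer span is read off the argmax of the start/end logits over the
# input sequence. "facebook/bart-base" is only an assumed checkpoint; its QA head is randomly initialized unless the
# checkpoint was fine-tuned on a QA dataset.
#
#     import jax.numpy as jnp
#     from transformers import AutoTokenizer, FlaxBartForQuestionAnswering
#
#     tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
#     model = FlaxBartForQuestionAnswering.from_pretrained("facebook/bart-base")
#     question, context = "Who introduced BART?", "BART was introduced by researchers at Facebook AI."
#     inputs = tokenizer(question, context, return_tensors="jax")
#     outputs = model(**inputs)
#     start = int(jnp.argmax(outputs.start_logits, axis=-1)[0])
#     end = int(jnp.argmax(outputs.end_logits, axis=-1)[0])
#     answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])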

class FlaxBartDecoderPreTrainedModel(FlaxPreTrainedModel):
    config_class = BartConfig
    base_model_prefix: str = "model"
    module_class: nn.Module = None

    def __init__(
        self,
        config: BartConfig,
        input_shape: Tuple[int] = (1, 1),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        config.is_decoder = True
        config.is_encoder_decoder = False
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
        # init input tensors
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)

        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        encoder_hidden_states = jnp.zeros(input_shape + (self.config.d_model,))
        encoder_attention_mask = attention_mask
        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            return_dict=False,
        )
        return module_init_outputs["params"]

    def init_cache(self, batch_size, max_length):
        r"""
        Args:
            batch_size (`int`):
                batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
            max_length (`int`):
                maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
                cache.
        """
        # init input variables to retrieve cache
        input_ids = jnp.ones((batch_size, max_length), dtype="i4")
        attention_mask = jnp.ones_like(input_ids, dtype="i4")
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

        init_variables = self.module.init(
            jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
        )
        return unfreeze(init_variables["cache"])

    @add_start_docstrings_to_model_forward(BART_DECODE_INPUTS_DOCSTRING)
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if encoder_hidden_states is not None and encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        # prepare decoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        inputs = {"params": params or self.params}

        # If past_key_values are passed, the cache is already initialized and the private flag init_cache has to be
        # passed down to ensure the cache is used. The cache must also be marked as mutable so that it can be
        # updated by the FlaxBartAttention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs

class FlaxBartDecoderWrapper(nn.Module):
    """
    This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
    used in combination with the [`EncoderDecoderModel`] framework.
    """

    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        embed_dim = self.config.d_model
        embed_tokens = nn.Embed(
            self.config.vocab_size,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
            dtype=self.dtype,
        )
        self.decoder = FlaxBartDecoder(config=self.config, embed_tokens=embed_tokens, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        return self.decoder(*args, **kwargs)


class FlaxBartForCausalLMModule(nn.Module):
    config: BartConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.model = FlaxBartDecoderWrapper(config=self.config, dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        outputs = self.model(
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]

        if self.config.tie_word_embeddings:
            shared_embedding = self.model.variables["params"]["decoder"]["embed_tokens"]["embedding"]
            lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
        else:
            lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + outputs[1:]

        return FlaxCausalLMOutputWithCrossAttentions(
            logits=lm_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

@add_start_docstrings(
    """
    Bart Decoder Model with a language modeling head on top (linear layer with weights tied to the input embeddings)
    e.g. for autoregressive tasks.
    """,
    BART_START_DOCSTRING,
)
class FlaxBartForCausalLM(FlaxBartDecoderPreTrainedModel):
    module_class = FlaxBartForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs


append_call_sample_docstring(
    FlaxBartForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)
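
# A minimal usage sketch for decoder-only generation: `generate()` calls `prepare_inputs_for_generation` once to build
# the static attention mask and cache, then `update_inputs_for_generation` after every step. "facebook/bart-base" and
# `max_length=20` are only assumed example values; loading a full BART checkpoint into the decoder-only class will
# discard the encoder weights.
#
#     from transformers import AutoTokenizer, FlaxBartForCausalLM
#
#     tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
#     model = FlaxBartForCausalLM.from_pretrained("facebook/bart-base")
#     inputs = tokenizer("My friends are cool", return_tensors="np")
#     generated = model.generate(inputs["input_ids"], max_length=20)
#     print(tokenizer.batch_decode(generated.sequences, skip_special_tokens=True))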