modeling_flax_roberta.py

  1. # coding=utf-8
  2. # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Callable, Optional, Tuple
  16. import flax.linen as nn
  17. import jax
  18. import jax.numpy as jnp
  19. import numpy as np
  20. from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
  21. from flax.linen import combine_masks, make_causal_mask
  22. from flax.linen import partitioning as nn_partitioning
  23. from flax.linen.attention import dot_product_attention_weights
  24. from flax.traverse_util import flatten_dict, unflatten_dict
  25. from jax import lax
  26. from ...modeling_flax_outputs import (
  27. FlaxBaseModelOutputWithPastAndCrossAttentions,
  28. FlaxBaseModelOutputWithPooling,
  29. FlaxBaseModelOutputWithPoolingAndCrossAttentions,
  30. FlaxCausalLMOutputWithCrossAttentions,
  31. FlaxMaskedLMOutput,
  32. FlaxMultipleChoiceModelOutput,
  33. FlaxQuestionAnsweringModelOutput,
  34. FlaxSequenceClassifierOutput,
  35. FlaxTokenClassifierOutput,
  36. )
  37. from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring
  38. from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
  39. from .configuration_roberta import RobertaConfig
  40. logger = logging.get_logger(__name__)
  41. _CHECKPOINT_FOR_DOC = "FacebookAI/roberta-base"
  42. _CONFIG_FOR_DOC = "RobertaConfig"
  43. remat = nn_partitioning.remat
  44. def create_position_ids_from_input_ids(input_ids, padding_idx):
  45. """
  46. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  47. are ignored. This is modified from fairseq's `utils.make_positions`.
  48. Args:
  49. input_ids: jnp.ndarray
  50. padding_idx: int
  51. Returns: jnp.ndarray
  52. """
  53. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  54. mask = (input_ids != padding_idx).astype("i4")
  55. if mask.ndim > 2:
  56. mask = mask.reshape((-1, mask.shape[-1]))
  57. incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
  58. incremental_indices = incremental_indices.reshape(input_ids.shape)
  59. else:
  60. incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
  61. return incremental_indices.astype("i4") + padding_idx
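# Worked example (editor's sketch): for RoBERTa, padding_idx is normally 1 (the <pad> token id).
# With input_ids = [[0, 31414, 232, 2, 1, 1]] and padding_idx = 1:
#   mask                  -> [1, 1, 1, 1, 0, 0]
#   cumsum(mask) * mask   -> [1, 2, 3, 4, 0, 0]
#   + padding_idx         -> [2, 3, 4, 5, 1, 1]
# i.e. real tokens get positions starting at padding_idx + 1 and padding positions keep padding_idx.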
  62. ROBERTA_START_DOCSTRING = r"""
  63. This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
  64. library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
  65. This model is also a
  66. [flax.linen.Module](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html) subclass. Use it as
  67. a regular Flax linen Module and refer to the Flax documentation for all matters related to general usage and
  68. behavior.
  69. Finally, this model supports inherent JAX features such as:
  70. - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
  71. - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
  72. - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
  73. - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
  74. Parameters:
  75. config ([`RobertaConfig`]): Model configuration class with all the parameters of the
  76. model. Initializing with a config file does not load the weights associated with the model, only the
  77. configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
  78. """
  79. ROBERTA_INPUTS_DOCSTRING = r"""
  80. Args:
  81. input_ids (`numpy.ndarray` of shape `({0})`):
  82. Indices of input sequence tokens in the vocabulary.
  83. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  84. [`PreTrainedTokenizer.__call__`] for details.
  85. [What are input IDs?](../glossary#input-ids)
  86. attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
  87. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  88. - 1 for tokens that are **not masked**,
  89. - 0 for tokens that are **masked**.
  90. [What are attention masks?](../glossary#attention-mask)
  91. token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
  92. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  93. 1]`:
  94. - 0 corresponds to a *sentence A* token,
  95. - 1 corresponds to a *sentence B* token.
  96. [What are token type IDs?](../glossary#token-type-ids)
  97. position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
  98. Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
  99. config.max_position_embeddings - 1]`.
  100. head_mask (`numpy.ndarray` of shape `({0})`, *optional*):
  101. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  102. - 1 indicates the head is **not masked**,
  103. - 0 indicates the head is **masked**.
  104. return_dict (`bool`, *optional*):
  105. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  106. """
  107. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
  108. class FlaxRobertaEmbeddings(nn.Module):
  109. """Construct the embeddings from word, position and token_type embeddings."""
  110. config: RobertaConfig
  111. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  112. def setup(self):
  113. self.word_embeddings = nn.Embed(
  114. self.config.vocab_size,
  115. self.config.hidden_size,
  116. embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
  117. dtype=self.dtype,
  118. )
  119. self.position_embeddings = nn.Embed(
  120. self.config.max_position_embeddings,
  121. self.config.hidden_size,
  122. embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
  123. dtype=self.dtype,
  124. )
  125. self.token_type_embeddings = nn.Embed(
  126. self.config.type_vocab_size,
  127. self.config.hidden_size,
  128. embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
  129. dtype=self.dtype,
  130. )
  131. self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
  132. self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
  133. def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
  134. # Embed
  135. inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
  136. position_embeds = self.position_embeddings(position_ids.astype("i4"))
  137. token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
  138. # Sum all embeddings
  139. hidden_states = inputs_embeds + token_type_embeddings + position_embeds
  140. # Layer Norm
  141. hidden_states = self.LayerNorm(hidden_states)
  142. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  143. return hidden_states
  144. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
  145. class FlaxRobertaSelfAttention(nn.Module):
  146. config: RobertaConfig
  147. causal: bool = False
  148. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  149. def setup(self):
  150. self.head_dim = self.config.hidden_size // self.config.num_attention_heads
  151. if self.config.hidden_size % self.config.num_attention_heads != 0:
  152. raise ValueError(
  153. f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
  154. f"`config.num_attention_heads`: {self.config.num_attention_heads}"
  155. )
  156. self.query = nn.Dense(
  157. self.config.hidden_size,
  158. dtype=self.dtype,
  159. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  160. )
  161. self.key = nn.Dense(
  162. self.config.hidden_size,
  163. dtype=self.dtype,
  164. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  165. )
  166. self.value = nn.Dense(
  167. self.config.hidden_size,
  168. dtype=self.dtype,
  169. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  170. )
  171. if self.causal:
  172. self.causal_mask = make_causal_mask(
  173. jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
  174. )
  175. def _split_heads(self, hidden_states):
  176. return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
  177. def _merge_heads(self, hidden_states):
  178. return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
  179. @nn.compact
  180. # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
  181. def _concatenate_to_cache(self, key, value, query, attention_mask):
  182. """
  183. This function takes projected key, value states from a single input token and concatenates the states to cached
  184. states from previous steps. This function is slightly adapted from the official Flax repository:
  185. https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
  186. """
  187. # detect if we're initializing by absence of existing cache data.
  188. is_initialized = self.has_variable("cache", "cached_key")
  189. cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
  190. cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
  191. cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
  192. if is_initialized:
  193. *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
  194. # update key, value caches with our new 1d spatial slices
  195. cur_index = cache_index.value
  196. indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
  197. key = lax.dynamic_update_slice(cached_key.value, key, indices)
  198. value = lax.dynamic_update_slice(cached_value.value, value, indices)
  199. cached_key.value = key
  200. cached_value.value = value
  201. num_updated_cache_vectors = query.shape[1]
  202. cache_index.value = cache_index.value + num_updated_cache_vectors
  203. # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
  204. pad_mask = jnp.broadcast_to(
  205. jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
  206. tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
  207. )
  208. attention_mask = combine_masks(pad_mask, attention_mask)
  209. return key, value, attention_mask
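# Editor's note on the cache layout (an illustration, assuming batch_dims == (batch_size,)):
# cached_key / cached_value have shape (batch_size, max_length, num_heads, head_dim) and cache_index
# tracks the next write position. For max_length = 4, cur_index = 2 and one new token:
#   cached_key: [k0, k1, 0, 0] --dynamic_update_slice at index 2--> [k0, k1, k_new, 0]
# pad_mask then only allows attention to positions < cur_index + num_updated_cache_vectors.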
  210. def __call__(
  211. self,
  212. hidden_states,
  213. attention_mask,
  214. layer_head_mask,
  215. key_value_states: Optional[jnp.ndarray] = None,
  216. init_cache: bool = False,
  217. deterministic=True,
  218. output_attentions: bool = False,
  219. ):
  220. # if key_value_states are provided this layer is used as a cross-attention layer
  221. # for the decoder
  222. is_cross_attention = key_value_states is not None
  223. batch_size = hidden_states.shape[0]
  224. # get query proj
  225. query_states = self.query(hidden_states)
  226. # get key, value proj
  227. if is_cross_attention:
  228. # cross_attentions
  229. key_states = self.key(key_value_states)
  230. value_states = self.value(key_value_states)
  231. else:
  232. # self_attention
  233. key_states = self.key(hidden_states)
  234. value_states = self.value(hidden_states)
  235. query_states = self._split_heads(query_states)
  236. key_states = self._split_heads(key_states)
  237. value_states = self._split_heads(value_states)
  238. # handle cache prepare causal attention mask
  239. if self.causal:
  240. query_length, key_length = query_states.shape[1], key_states.shape[1]
  241. if self.has_variable("cache", "cached_key"):
  242. mask_shift = self.variables["cache"]["cache_index"]
  243. max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
  244. causal_mask = lax.dynamic_slice(
  245. self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
  246. )
  247. else:
  248. causal_mask = self.causal_mask[:, :, :query_length, :key_length]
  249. causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
  250. # combine masks if needed
  251. if attention_mask is not None and self.causal:
  252. attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
  253. attention_mask = combine_masks(attention_mask, causal_mask)
  254. elif self.causal:
  255. attention_mask = causal_mask
  256. elif attention_mask is not None:
  257. attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
  258. # During fast autoregressive decoding, we feed one position at a time,
  259. # and cache the keys and values step by step.
  260. if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
  261. key_states, value_states, attention_mask = self._concatenate_to_cache(
  262. key_states, value_states, query_states, attention_mask
  263. )
  264. # Convert the boolean attention mask to an attention bias.
  265. if attention_mask is not None:
  266. # attention mask in the form of attention bias
  267. attention_bias = lax.select(
  268. attention_mask > 0,
  269. jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
  270. jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
  271. )
  272. else:
  273. attention_bias = None
  274. dropout_rng = None
  275. if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
  276. dropout_rng = self.make_rng("dropout")
  277. attn_weights = dot_product_attention_weights(
  278. query_states,
  279. key_states,
  280. bias=attention_bias,
  281. dropout_rng=dropout_rng,
  282. dropout_rate=self.config.attention_probs_dropout_prob,
  283. broadcast_dropout=True,
  284. deterministic=deterministic,
  285. dtype=self.dtype,
  286. precision=None,
  287. )
  288. # Mask heads if we want to
  289. if layer_head_mask is not None:
  290. attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
  291. attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
  292. attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
  293. outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
  294. return outputs
  295. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
  296. class FlaxRobertaSelfOutput(nn.Module):
  297. config: RobertaConfig
  298. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  299. def setup(self):
  300. self.dense = nn.Dense(
  301. self.config.hidden_size,
  302. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  303. dtype=self.dtype,
  304. )
  305. self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
  306. self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
  307. def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
  308. hidden_states = self.dense(hidden_states)
  309. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  310. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  311. return hidden_states
  312. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
  313. class FlaxRobertaAttention(nn.Module):
  314. config: RobertaConfig
  315. causal: bool = False
  316. dtype: jnp.dtype = jnp.float32
  317. def setup(self):
  318. self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
  319. self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
  320. def __call__(
  321. self,
  322. hidden_states,
  323. attention_mask,
  324. layer_head_mask,
  325. key_value_states=None,
  326. init_cache=False,
  327. deterministic=True,
  328. output_attentions: bool = False,
  329. ):
  330. # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
  331. # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
  332. # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
  333. attn_outputs = self.self(
  334. hidden_states,
  335. attention_mask,
  336. layer_head_mask=layer_head_mask,
  337. key_value_states=key_value_states,
  338. init_cache=init_cache,
  339. deterministic=deterministic,
  340. output_attentions=output_attentions,
  341. )
  342. attn_output = attn_outputs[0]
  343. hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
  344. outputs = (hidden_states,)
  345. if output_attentions:
  346. outputs += (attn_outputs[1],)
  347. return outputs
  348. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
  349. class FlaxRobertaIntermediate(nn.Module):
  350. config: RobertaConfig
  351. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  352. def setup(self):
  353. self.dense = nn.Dense(
  354. self.config.intermediate_size,
  355. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  356. dtype=self.dtype,
  357. )
  358. self.activation = ACT2FN[self.config.hidden_act]
  359. def __call__(self, hidden_states):
  360. hidden_states = self.dense(hidden_states)
  361. hidden_states = self.activation(hidden_states)
  362. return hidden_states
  363. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
  364. class FlaxRobertaOutput(nn.Module):
  365. config: RobertaConfig
  366. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  367. def setup(self):
  368. self.dense = nn.Dense(
  369. self.config.hidden_size,
  370. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  371. dtype=self.dtype,
  372. )
  373. self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
  374. self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
  375. def __call__(self, hidden_states, attention_output, deterministic: bool = True):
  376. hidden_states = self.dense(hidden_states)
  377. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  378. hidden_states = self.LayerNorm(hidden_states + attention_output)
  379. return hidden_states
  380. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta
  381. class FlaxRobertaLayer(nn.Module):
  382. config: RobertaConfig
  383. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  384. def setup(self):
  385. self.attention = FlaxRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
  386. self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
  387. self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
  388. if self.config.add_cross_attention:
  389. self.crossattention = FlaxRobertaAttention(self.config, causal=False, dtype=self.dtype)
  390. def __call__(
  391. self,
  392. hidden_states,
  393. attention_mask,
  394. layer_head_mask,
  395. encoder_hidden_states: Optional[jnp.ndarray] = None,
  396. encoder_attention_mask: Optional[jnp.ndarray] = None,
  397. init_cache: bool = False,
  398. deterministic: bool = True,
  399. output_attentions: bool = False,
  400. ):
  401. # Self Attention
  402. attention_outputs = self.attention(
  403. hidden_states,
  404. attention_mask,
  405. layer_head_mask=layer_head_mask,
  406. init_cache=init_cache,
  407. deterministic=deterministic,
  408. output_attentions=output_attentions,
  409. )
  410. attention_output = attention_outputs[0]
  411. # Cross-Attention Block
  412. if encoder_hidden_states is not None:
  413. cross_attention_outputs = self.crossattention(
  414. attention_output,
  415. attention_mask=encoder_attention_mask,
  416. layer_head_mask=layer_head_mask,
  417. key_value_states=encoder_hidden_states,
  418. deterministic=deterministic,
  419. output_attentions=output_attentions,
  420. )
  421. attention_output = cross_attention_outputs[0]
  422. hidden_states = self.intermediate(attention_output)
  423. hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
  424. outputs = (hidden_states,)
  425. if output_attentions:
  426. outputs += (attention_outputs[1],)
  427. if encoder_hidden_states is not None:
  428. outputs += (cross_attention_outputs[1],)
  429. return outputs
  430. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
  431. class FlaxRobertaLayerCollection(nn.Module):
  432. config: RobertaConfig
  433. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  434. gradient_checkpointing: bool = False
  435. def setup(self):
  436. if self.gradient_checkpointing:
  437. FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7))
  438. self.layers = [
  439. FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
  440. for i in range(self.config.num_hidden_layers)
  441. ]
  442. else:
  443. self.layers = [
  444. FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype)
  445. for i in range(self.config.num_hidden_layers)
  446. ]
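# Editor's note: in the gradient-checkpointing branch above, static_argnums=(5, 6, 7) appears to mark
# the positional arguments init_cache, deterministic and output_attentions of FlaxRobertaLayer.__call__
# as static, so these Python booleans trigger retracing instead of being traced as JAX values.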
  447. def __call__(
  448. self,
  449. hidden_states,
  450. attention_mask,
  451. head_mask,
  452. encoder_hidden_states: Optional[jnp.ndarray] = None,
  453. encoder_attention_mask: Optional[jnp.ndarray] = None,
  454. init_cache: bool = False,
  455. deterministic: bool = True,
  456. output_attentions: bool = False,
  457. output_hidden_states: bool = False,
  458. return_dict: bool = True,
  459. ):
  460. all_attentions = () if output_attentions else None
  461. all_hidden_states = () if output_hidden_states else None
  462. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  463. # Check if head_mask has a correct number of layers specified if desired
  464. if head_mask is not None:
  465. if head_mask.shape[0] != len(self.layers):
  466. raise ValueError(
  467. f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
  468. f"{head_mask.shape[0]}."
  469. )
  470. for i, layer in enumerate(self.layers):
  471. if output_hidden_states:
  472. all_hidden_states += (hidden_states,)
  473. layer_outputs = layer(
  474. hidden_states,
  475. attention_mask,
  476. head_mask[i] if head_mask is not None else None,
  477. encoder_hidden_states,
  478. encoder_attention_mask,
  479. init_cache,
  480. deterministic,
  481. output_attentions,
  482. )
  483. hidden_states = layer_outputs[0]
  484. if output_attentions:
  485. all_attentions += (layer_outputs[1],)
  486. if encoder_hidden_states is not None:
  487. all_cross_attentions += (layer_outputs[2],)
  488. if output_hidden_states:
  489. all_hidden_states += (hidden_states,)
  490. outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
  491. if not return_dict:
  492. return tuple(v for v in outputs if v is not None)
  493. return FlaxBaseModelOutputWithPastAndCrossAttentions(
  494. last_hidden_state=hidden_states,
  495. hidden_states=all_hidden_states,
  496. attentions=all_attentions,
  497. cross_attentions=all_cross_attentions,
  498. )
  499. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
  500. class FlaxRobertaEncoder(nn.Module):
  501. config: RobertaConfig
  502. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  503. gradient_checkpointing: bool = False
  504. def setup(self):
  505. self.layer = FlaxRobertaLayerCollection(
  506. self.config,
  507. dtype=self.dtype,
  508. gradient_checkpointing=self.gradient_checkpointing,
  509. )
  510. def __call__(
  511. self,
  512. hidden_states,
  513. attention_mask,
  514. head_mask,
  515. encoder_hidden_states: Optional[jnp.ndarray] = None,
  516. encoder_attention_mask: Optional[jnp.ndarray] = None,
  517. init_cache: bool = False,
  518. deterministic: bool = True,
  519. output_attentions: bool = False,
  520. output_hidden_states: bool = False,
  521. return_dict: bool = True,
  522. ):
  523. return self.layer(
  524. hidden_states,
  525. attention_mask,
  526. head_mask=head_mask,
  527. encoder_hidden_states=encoder_hidden_states,
  528. encoder_attention_mask=encoder_attention_mask,
  529. init_cache=init_cache,
  530. deterministic=deterministic,
  531. output_attentions=output_attentions,
  532. output_hidden_states=output_hidden_states,
  533. return_dict=return_dict,
  534. )
  535. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
  536. class FlaxRobertaPooler(nn.Module):
  537. config: RobertaConfig
  538. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  539. def setup(self):
  540. self.dense = nn.Dense(
  541. self.config.hidden_size,
  542. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  543. dtype=self.dtype,
  544. )
  545. def __call__(self, hidden_states):
  546. cls_hidden_state = hidden_states[:, 0]
  547. cls_hidden_state = self.dense(cls_hidden_state)
  548. return nn.tanh(cls_hidden_state)
  549. class FlaxRobertaLMHead(nn.Module):
  550. config: RobertaConfig
  551. dtype: jnp.dtype = jnp.float32
  552. bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
  553. def setup(self):
  554. self.dense = nn.Dense(
  555. self.config.hidden_size,
  556. dtype=self.dtype,
  557. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  558. )
  559. self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
  560. self.decoder = nn.Dense(
  561. self.config.vocab_size,
  562. dtype=self.dtype,
  563. use_bias=False,
  564. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  565. )
  566. self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
  567. def __call__(self, hidden_states, shared_embedding=None):
  568. hidden_states = self.dense(hidden_states)
  569. hidden_states = ACT2FN["gelu"](hidden_states)
  570. hidden_states = self.layer_norm(hidden_states)
  571. if shared_embedding is not None:
  572. hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
  573. else:
  574. hidden_states = self.decoder(hidden_states)
  575. bias = jnp.asarray(self.bias, self.dtype)
  576. hidden_states += bias
  577. return hidden_states
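# Editor's note: when word embeddings are tied (config.tie_word_embeddings), the caller passes the
# embedding matrix of shape (vocab_size, hidden_size) as `shared_embedding`; the head then applies
# its decoder with kernel = shared_embedding.T instead of the decoder's own (untied) kernel, e.g.:
#   shared = variables["params"]["embeddings"]["word_embeddings"]["embedding"]
#   logits = lm_head(hidden_states, shared_embedding=shared)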
  578. class FlaxRobertaClassificationHead(nn.Module):
  579. config: RobertaConfig
  580. dtype: jnp.dtype = jnp.float32
  581. def setup(self):
  582. self.dense = nn.Dense(
  583. self.config.hidden_size,
  584. dtype=self.dtype,
  585. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  586. )
  587. classifier_dropout = (
  588. self.config.classifier_dropout
  589. if self.config.classifier_dropout is not None
  590. else self.config.hidden_dropout_prob
  591. )
  592. self.dropout = nn.Dropout(rate=classifier_dropout)
  593. self.out_proj = nn.Dense(
  594. self.config.num_labels,
  595. dtype=self.dtype,
  596. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  597. )
  598. def __call__(self, hidden_states, deterministic=True):
  599. hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
  600. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  601. hidden_states = self.dense(hidden_states)
  602. hidden_states = nn.tanh(hidden_states)
  603. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  604. hidden_states = self.out_proj(hidden_states)
  605. return hidden_states
  606. class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
  607. """
  608. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  609. models.
  610. """
  611. config_class = RobertaConfig
  612. base_model_prefix = "roberta"
  613. module_class: nn.Module = None
  614. def __init__(
  615. self,
  616. config: RobertaConfig,
  617. input_shape: Tuple = (1, 1),
  618. seed: int = 0,
  619. dtype: jnp.dtype = jnp.float32,
  620. _do_init: bool = True,
  621. gradient_checkpointing: bool = False,
  622. **kwargs,
  623. ):
  624. module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
  625. super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
  626. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
  627. def enable_gradient_checkpointing(self):
  628. self._module = self.module_class(
  629. config=self.config,
  630. dtype=self.dtype,
  631. gradient_checkpointing=True,
  632. )
  633. def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
  634. # init input tensors
  635. input_ids = jnp.zeros(input_shape, dtype="i4")
  636. token_type_ids = jnp.ones_like(input_ids)
  637. position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
  638. attention_mask = jnp.ones_like(input_ids)
  639. head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
  640. params_rng, dropout_rng = jax.random.split(rng)
  641. rngs = {"params": params_rng, "dropout": dropout_rng}
  642. if self.config.add_cross_attention:
  643. encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
  644. encoder_attention_mask = attention_mask
  645. module_init_outputs = self.module.init(
  646. rngs,
  647. input_ids,
  648. attention_mask,
  649. token_type_ids,
  650. position_ids,
  651. head_mask,
  652. encoder_hidden_states,
  653. encoder_attention_mask,
  654. return_dict=False,
  655. )
  656. else:
  657. module_init_outputs = self.module.init(
  658. rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
  659. )
  660. random_params = module_init_outputs["params"]
  661. if params is not None:
  662. random_params = flatten_dict(unfreeze(random_params))
  663. params = flatten_dict(unfreeze(params))
  664. for missing_key in self._missing_keys:
  665. params[missing_key] = random_params[missing_key]
  666. self._missing_keys = set()
  667. return freeze(unflatten_dict(params))
  668. else:
  669. return random_params
  670. # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
  671. def init_cache(self, batch_size, max_length):
  672. r"""
  673. Args:
  674. batch_size (`int`):
  675. batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
  676. max_length (`int`):
  677. maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
  678. cache.
  679. """
  680. # init input variables to retrieve cache
  681. input_ids = jnp.ones((batch_size, max_length), dtype="i4")
  682. attention_mask = jnp.ones_like(input_ids, dtype="i4")
  683. position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
  684. init_variables = self.module.init(
  685. jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
  686. )
  687. return unfreeze(init_variables["cache"])
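# A minimal usage sketch (editor's addition); the checkpoint, `is_decoder` override and lengths are
# illustrative assumptions:
#
#   model = FlaxRobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", is_decoder=True)
#   past_key_values = model.init_cache(batch_size=1, max_length=20)
#   # `past_key_values` is then fed back into the model call (or via prepare_inputs_for_generation)
#   # so that keys/values are cached step by step during autoregressive decoding.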
  688. @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  689. def __call__(
  690. self,
  691. input_ids,
  692. attention_mask=None,
  693. token_type_ids=None,
  694. position_ids=None,
  695. head_mask=None,
  696. encoder_hidden_states=None,
  697. encoder_attention_mask=None,
  698. params: dict = None,
  699. dropout_rng: jax.random.PRNGKey = None,
  700. train: bool = False,
  701. output_attentions: Optional[bool] = None,
  702. output_hidden_states: Optional[bool] = None,
  703. return_dict: Optional[bool] = None,
  704. past_key_values: dict = None,
  705. ):
  706. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  707. output_hidden_states = (
  708. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  709. )
  710. return_dict = return_dict if return_dict is not None else self.config.return_dict
  711. # init input tensors if not passed
  712. if token_type_ids is None:
  713. token_type_ids = jnp.zeros_like(input_ids)
  714. if position_ids is None:
  715. position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
  716. if attention_mask is None:
  717. attention_mask = jnp.ones_like(input_ids)
  718. if head_mask is None:
  719. head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
  720. # Handle any PRNG if needed
  721. rngs = {}
  722. if dropout_rng is not None:
  723. rngs["dropout"] = dropout_rng
  724. inputs = {"params": params or self.params}
  725. if self.config.add_cross_attention:
  726. # If past_key_values are passed, the cache is already initialized and the private flag init_cache has to be
  727. # passed down to ensure the cache is used. It also has to be made sure that the cache is marked as mutable so
  728. # that it can be changed by the FlaxRobertaAttention module.
  729. if past_key_values:
  730. inputs["cache"] = past_key_values
  731. mutable = ["cache"]
  732. else:
  733. mutable = False
  734. outputs = self.module.apply(
  735. inputs,
  736. jnp.array(input_ids, dtype="i4"),
  737. jnp.array(attention_mask, dtype="i4"),
  738. token_type_ids=jnp.array(token_type_ids, dtype="i4"),
  739. position_ids=jnp.array(position_ids, dtype="i4"),
  740. head_mask=jnp.array(head_mask, dtype="i4"),
  741. encoder_hidden_states=encoder_hidden_states,
  742. encoder_attention_mask=encoder_attention_mask,
  743. deterministic=not train,
  744. output_attentions=output_attentions,
  745. output_hidden_states=output_hidden_states,
  746. return_dict=return_dict,
  747. rngs=rngs,
  748. mutable=mutable,
  749. )
  750. # add updated cache to model output
  751. if past_key_values is not None and return_dict:
  752. outputs, past_key_values = outputs
  753. outputs["past_key_values"] = unfreeze(past_key_values["cache"])
  754. return outputs
  755. elif past_key_values is not None and not return_dict:
  756. outputs, past_key_values = outputs
  757. outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
  758. else:
  759. outputs = self.module.apply(
  760. inputs,
  761. jnp.array(input_ids, dtype="i4"),
  762. jnp.array(attention_mask, dtype="i4"),
  763. token_type_ids=jnp.array(token_type_ids, dtype="i4"),
  764. position_ids=jnp.array(position_ids, dtype="i4"),
  765. head_mask=jnp.array(head_mask, dtype="i4"),
  766. deterministic=not train,
  767. output_attentions=output_attentions,
  768. output_hidden_states=output_hidden_states,
  769. return_dict=return_dict,
  770. rngs=rngs,
  771. )
  772. return outputs
  773. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
  774. class FlaxRobertaModule(nn.Module):
  775. config: RobertaConfig
  776. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  777. add_pooling_layer: bool = True
  778. gradient_checkpointing: bool = False
  779. def setup(self):
  780. self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
  781. self.encoder = FlaxRobertaEncoder(
  782. self.config,
  783. dtype=self.dtype,
  784. gradient_checkpointing=self.gradient_checkpointing,
  785. )
  786. self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
  787. def __call__(
  788. self,
  789. input_ids,
  790. attention_mask,
  791. token_type_ids: Optional[jnp.ndarray] = None,
  792. position_ids: Optional[jnp.ndarray] = None,
  793. head_mask: Optional[jnp.ndarray] = None,
  794. encoder_hidden_states: Optional[jnp.ndarray] = None,
  795. encoder_attention_mask: Optional[jnp.ndarray] = None,
  796. init_cache: bool = False,
  797. deterministic: bool = True,
  798. output_attentions: bool = False,
  799. output_hidden_states: bool = False,
  800. return_dict: bool = True,
  801. ):
  802. # make sure `token_type_ids` is correctly initialized when not passed
  803. if token_type_ids is None:
  804. token_type_ids = jnp.zeros_like(input_ids)
  805. # make sure `position_ids` is correctly initialized when not passed
  806. if position_ids is None:
  807. position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
  808. hidden_states = self.embeddings(
  809. input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
  810. )
  811. outputs = self.encoder(
  812. hidden_states,
  813. attention_mask,
  814. head_mask=head_mask,
  815. deterministic=deterministic,
  816. encoder_hidden_states=encoder_hidden_states,
  817. encoder_attention_mask=encoder_attention_mask,
  818. init_cache=init_cache,
  819. output_attentions=output_attentions,
  820. output_hidden_states=output_hidden_states,
  821. return_dict=return_dict,
  822. )
  823. hidden_states = outputs[0]
  824. pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
  825. if not return_dict:
  826. # if pooled is None, don't return it
  827. if pooled is None:
  828. return (hidden_states,) + outputs[1:]
  829. return (hidden_states, pooled) + outputs[1:]
  830. return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
  831. last_hidden_state=hidden_states,
  832. pooler_output=pooled,
  833. hidden_states=outputs.hidden_states,
  834. attentions=outputs.attentions,
  835. cross_attentions=outputs.cross_attentions,
  836. )
  837. @add_start_docstrings(
  838. "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
  839. ROBERTA_START_DOCSTRING,
  840. )
  841. class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
  842. module_class = FlaxRobertaModule
  843. append_call_sample_docstring(FlaxRobertaModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
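# A minimal usage sketch (editor's addition, mirroring the sample docstring appended above):
#
#   from transformers import AutoTokenizer, FlaxRobertaModel
#
#   tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
#   model = FlaxRobertaModel.from_pretrained("FacebookAI/roberta-base")
#   inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
#   outputs = model(**inputs)
#   outputs.last_hidden_state.shape   # (1, sequence_length, hidden_size)
#   outputs.pooler_output.shape       # (1, hidden_size)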
  844. class FlaxRobertaForMaskedLMModule(nn.Module):
  845. config: RobertaConfig
  846. dtype: jnp.dtype = jnp.float32
  847. gradient_checkpointing: bool = False
  848. def setup(self):
  849. self.roberta = FlaxRobertaModule(
  850. config=self.config,
  851. add_pooling_layer=False,
  852. dtype=self.dtype,
  853. gradient_checkpointing=self.gradient_checkpointing,
  854. )
  855. self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
  856. def __call__(
  857. self,
  858. input_ids,
  859. attention_mask,
  860. token_type_ids,
  861. position_ids,
  862. head_mask,
  863. deterministic: bool = True,
  864. output_attentions: bool = False,
  865. output_hidden_states: bool = False,
  866. return_dict: bool = True,
  867. ):
  868. # Model
  869. outputs = self.roberta(
  870. input_ids,
  871. attention_mask,
  872. token_type_ids,
  873. position_ids,
  874. head_mask,
  875. deterministic=deterministic,
  876. output_attentions=output_attentions,
  877. output_hidden_states=output_hidden_states,
  878. return_dict=return_dict,
  879. )
  880. hidden_states = outputs[0]
  881. if self.config.tie_word_embeddings:
  882. shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
  883. else:
  884. shared_embedding = None
  885. # Compute the prediction scores
  886. logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
  887. if not return_dict:
  888. return (logits,) + outputs[1:]
  889. return FlaxMaskedLMOutput(
  890. logits=logits,
  891. hidden_states=outputs.hidden_states,
  892. attentions=outputs.attentions,
  893. )
  894. @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING)
  895. class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel):
  896. module_class = FlaxRobertaForMaskedLMModule
  897. append_call_sample_docstring(
  898. FlaxRobertaForMaskedLM,
  899. _CHECKPOINT_FOR_DOC,
  900. FlaxMaskedLMOutput,
  901. _CONFIG_FOR_DOC,
  902. mask="<mask>",
  903. )
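# A minimal fill-mask sketch (editor's addition); the example sentence is an illustrative assumption:
#
#   from transformers import AutoTokenizer, FlaxRobertaForMaskedLM
#
#   tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
#   model = FlaxRobertaForMaskedLM.from_pretrained("FacebookAI/roberta-base")
#   inputs = tokenizer("The capital of France is <mask>.", return_tensors="np")
#   logits = model(**inputs).logits
#   mask_index = int((inputs["input_ids"][0] == tokenizer.mask_token_id).argmax())
#   predicted_id = int(logits[0, mask_index].argmax(-1))
#   tokenizer.decode([predicted_id])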
  904. class FlaxRobertaForSequenceClassificationModule(nn.Module):
  905. config: RobertaConfig
  906. dtype: jnp.dtype = jnp.float32
  907. gradient_checkpointing: bool = False
  908. def setup(self):
  909. self.roberta = FlaxRobertaModule(
  910. config=self.config,
  911. dtype=self.dtype,
  912. add_pooling_layer=False,
  913. gradient_checkpointing=self.gradient_checkpointing,
  914. )
  915. self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype)
  916. def __call__(
  917. self,
  918. input_ids,
  919. attention_mask,
  920. token_type_ids,
  921. position_ids,
  922. head_mask,
  923. deterministic: bool = True,
  924. output_attentions: bool = False,
  925. output_hidden_states: bool = False,
  926. return_dict: bool = True,
  927. ):
  928. # Model
  929. outputs = self.roberta(
  930. input_ids,
  931. attention_mask,
  932. token_type_ids,
  933. position_ids,
  934. head_mask,
  935. deterministic=deterministic,
  936. output_attentions=output_attentions,
  937. output_hidden_states=output_hidden_states,
  938. return_dict=return_dict,
  939. )
  940. sequence_output = outputs[0]
  941. logits = self.classifier(sequence_output, deterministic=deterministic)
  942. if not return_dict:
  943. return (logits,) + outputs[1:]
  944. return FlaxSequenceClassifierOutput(
  945. logits=logits,
  946. hidden_states=outputs.hidden_states,
  947. attentions=outputs.attentions,
  948. )
  949. @add_start_docstrings(
  950. """
  951. Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  952. pooled output) e.g. for GLUE tasks.
  953. """,
  954. ROBERTA_START_DOCSTRING,
  955. )
  956. class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel):
  957. module_class = FlaxRobertaForSequenceClassificationModule
  958. append_call_sample_docstring(
  959. FlaxRobertaForSequenceClassification,
  960. _CHECKPOINT_FOR_DOC,
  961. FlaxSequenceClassifierOutput,
  962. _CONFIG_FOR_DOC,
  963. )
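# A minimal sketch (editor's addition); `num_labels` and the input text are illustrative assumptions:
#
#   from transformers import AutoTokenizer, FlaxRobertaForSequenceClassification
#
#   tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
#   model = FlaxRobertaForSequenceClassification.from_pretrained("FacebookAI/roberta-base", num_labels=2)
#   inputs = tokenizer("A great movie!", return_tensors="np")
#   logits = model(**inputs).logits   # shape (1, num_labels)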
  964. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->Roberta, with self.bert->self.roberta
  965. class FlaxRobertaForMultipleChoiceModule(nn.Module):
  966. config: RobertaConfig
  967. dtype: jnp.dtype = jnp.float32
  968. gradient_checkpointing: bool = False
  969. def setup(self):
  970. self.roberta = FlaxRobertaModule(
  971. config=self.config,
  972. dtype=self.dtype,
  973. gradient_checkpointing=self.gradient_checkpointing,
  974. )
  975. self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
  976. self.classifier = nn.Dense(1, dtype=self.dtype)
  977. def __call__(
  978. self,
  979. input_ids,
  980. attention_mask,
  981. token_type_ids,
  982. position_ids,
  983. head_mask,
  984. deterministic: bool = True,
  985. output_attentions: bool = False,
  986. output_hidden_states: bool = False,
  987. return_dict: bool = True,
  988. ):
  989. num_choices = input_ids.shape[1]
  990. input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
  991. attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
  992. token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
  993. position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
  994. # Model
  995. outputs = self.roberta(
  996. input_ids,
  997. attention_mask,
  998. token_type_ids,
  999. position_ids,
  1000. head_mask,
  1001. deterministic=deterministic,
  1002. output_attentions=output_attentions,
  1003. output_hidden_states=output_hidden_states,
  1004. return_dict=return_dict,
  1005. )
  1006. pooled_output = outputs[1]
  1007. pooled_output = self.dropout(pooled_output, deterministic=deterministic)
  1008. logits = self.classifier(pooled_output)
  1009. reshaped_logits = logits.reshape(-1, num_choices)
  1010. if not return_dict:
  1011. return (reshaped_logits,) + outputs[2:]
  1012. return FlaxMultipleChoiceModelOutput(
  1013. logits=reshaped_logits,
  1014. hidden_states=outputs.hidden_states,
  1015. attentions=outputs.attentions,
  1016. )
  1017. @add_start_docstrings(
  1018. """
  1019. Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
  1020. softmax) e.g. for RocStories/SWAG tasks.
  1021. """,
  1022. ROBERTA_START_DOCSTRING,
  1023. )
  1024. class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel):
  1025. module_class = FlaxRobertaForMultipleChoiceModule
  1026. overwrite_call_docstring(
  1027. FlaxRobertaForMultipleChoice, ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
  1028. )
  1029. append_call_sample_docstring(
  1030. FlaxRobertaForMultipleChoice,
  1031. _CHECKPOINT_FOR_DOC,
  1032. FlaxMultipleChoiceModelOutput,
  1033. _CONFIG_FOR_DOC,
  1034. )
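# A minimal sketch (editor's addition) of the (batch_size, num_choices, sequence_length) input layout;
# the prompt and choices are illustrative assumptions:
#
#   from transformers import AutoTokenizer, FlaxRobertaForMultipleChoice
#
#   tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
#   model = FlaxRobertaForMultipleChoice.from_pretrained("FacebookAI/roberta-base")
#   prompt = "France is famous for"
#   choices = ["its cuisine.", "its rainforests."]
#   encoding = tokenizer([prompt, prompt], choices, return_tensors="np", padding=True)
#   inputs = {k: v[None, :] for k, v in encoding.items()}   # add the batch dim -> (1, 2, seq_len)
#   logits = model(**inputs).logits                          # shape (1, num_choices)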
  1035. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->Roberta, with self.bert->self.roberta
  1036. class FlaxRobertaForTokenClassificationModule(nn.Module):
  1037. config: RobertaConfig
  1038. dtype: jnp.dtype = jnp.float32
  1039. gradient_checkpointing: bool = False
  1040. def setup(self):
  1041. self.roberta = FlaxRobertaModule(
  1042. config=self.config,
  1043. dtype=self.dtype,
  1044. add_pooling_layer=False,
  1045. gradient_checkpointing=self.gradient_checkpointing,
  1046. )
  1047. classifier_dropout = (
  1048. self.config.classifier_dropout
  1049. if self.config.classifier_dropout is not None
  1050. else self.config.hidden_dropout_prob
  1051. )
  1052. self.dropout = nn.Dropout(rate=classifier_dropout)
  1053. self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
  1054. def __call__(
  1055. self,
  1056. input_ids,
  1057. attention_mask,
  1058. token_type_ids,
  1059. position_ids,
  1060. head_mask,
  1061. deterministic: bool = True,
  1062. output_attentions: bool = False,
  1063. output_hidden_states: bool = False,
  1064. return_dict: bool = True,
  1065. ):
  1066. # Model
  1067. outputs = self.roberta(
  1068. input_ids,
  1069. attention_mask,
  1070. token_type_ids,
  1071. position_ids,
  1072. head_mask,
  1073. deterministic=deterministic,
  1074. output_attentions=output_attentions,
  1075. output_hidden_states=output_hidden_states,
  1076. return_dict=return_dict,
  1077. )
  1078. hidden_states = outputs[0]
  1079. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  1080. logits = self.classifier(hidden_states)
  1081. if not return_dict:
  1082. return (logits,) + outputs[1:]
  1083. return FlaxTokenClassifierOutput(
  1084. logits=logits,
  1085. hidden_states=outputs.hidden_states,
  1086. attentions=outputs.attentions,
  1087. )
  1088. @add_start_docstrings(
  1089. """
  1090. Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
  1091. Named-Entity-Recognition (NER) tasks.
  1092. """,
  1093. ROBERTA_START_DOCSTRING,
  1094. )
  1095. class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel):
  1096. module_class = FlaxRobertaForTokenClassificationModule
  1097. append_call_sample_docstring(
  1098. FlaxRobertaForTokenClassification,
  1099. _CHECKPOINT_FOR_DOC,
  1100. FlaxTokenClassifierOutput,
  1101. _CONFIG_FOR_DOC,
  1102. )
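# A minimal sketch (editor's addition); the label count is an illustrative assumption:
#
#   from transformers import AutoTokenizer, FlaxRobertaForTokenClassification
#
#   tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
#   model = FlaxRobertaForTokenClassification.from_pretrained("FacebookAI/roberta-base", num_labels=9)
#   inputs = tokenizer("HuggingFace is based in New York City", return_tensors="np")
#   logits = model(**inputs).logits   # shape (1, sequence_length, num_labels), one set of logits per token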
  1103. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->Roberta, with self.bert->self.roberta
  1104. class FlaxRobertaForQuestionAnsweringModule(nn.Module):
  1105. config: RobertaConfig
  1106. dtype: jnp.dtype = jnp.float32
  1107. gradient_checkpointing: bool = False
  1108. def setup(self):
  1109. self.roberta = FlaxRobertaModule(
  1110. config=self.config,
  1111. dtype=self.dtype,
  1112. add_pooling_layer=False,
  1113. gradient_checkpointing=self.gradient_checkpointing,
  1114. )
  1115. self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
  1116. def __call__(
  1117. self,
  1118. input_ids,
  1119. attention_mask,
  1120. token_type_ids,
  1121. position_ids,
  1122. head_mask,
  1123. deterministic: bool = True,
  1124. output_attentions: bool = False,
  1125. output_hidden_states: bool = False,
  1126. return_dict: bool = True,
  1127. ):
  1128. # Model
  1129. outputs = self.roberta(
  1130. input_ids,
  1131. attention_mask,
  1132. token_type_ids,
  1133. position_ids,
  1134. head_mask,
  1135. deterministic=deterministic,
  1136. output_attentions=output_attentions,
  1137. output_hidden_states=output_hidden_states,
  1138. return_dict=return_dict,
  1139. )
  1140. hidden_states = outputs[0]
  1141. logits = self.qa_outputs(hidden_states)
  1142. start_logits, end_logits = jnp.split(logits, self.config.num_labels, axis=-1)
  1143. start_logits = start_logits.squeeze(-1)
  1144. end_logits = end_logits.squeeze(-1)
  1145. if not return_dict:
  1146. return (start_logits, end_logits) + outputs[1:]
  1147. return FlaxQuestionAnsweringModelOutput(
  1148. start_logits=start_logits,
  1149. end_logits=end_logits,
  1150. hidden_states=outputs.hidden_states,
  1151. attentions=outputs.attentions,
  1152. )
  1153. @add_start_docstrings(
  1154. """
  1155. Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
  1156. linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
  1157. """,
  1158. ROBERTA_START_DOCSTRING,
  1159. )
  1160. class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel):
  1161. module_class = FlaxRobertaForQuestionAnsweringModule
  1162. append_call_sample_docstring(
  1163. FlaxRobertaForQuestionAnswering,
  1164. _CHECKPOINT_FOR_DOC,
  1165. FlaxQuestionAnsweringModelOutput,
  1166. _CONFIG_FOR_DOC,
  1167. )
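# A minimal extractive-QA sketch (editor's addition); the question/context pair is an illustrative
# assumption, and a fine-tuned QA checkpoint is needed for the predicted span to be meaningful:
#
#   from transformers import AutoTokenizer, FlaxRobertaForQuestionAnswering
#
#   tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
#   model = FlaxRobertaForQuestionAnswering.from_pretrained("FacebookAI/roberta-base")
#   inputs = tokenizer("Who developed RoBERTa?", "RoBERTa was developed by Facebook AI.", return_tensors="np")
#   outputs = model(**inputs)
#   start = int(outputs.start_logits[0].argmax())
#   end = int(outputs.end_logits[0].argmax())
#   answer = tokenizer.decode(inputs["input_ids"][0, start : end + 1])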
  1168. class FlaxRobertaForCausalLMModule(nn.Module):
  1169. config: RobertaConfig
  1170. dtype: jnp.dtype = jnp.float32
  1171. gradient_checkpointing: bool = False
  1172. def setup(self):
  1173. self.roberta = FlaxRobertaModule(
  1174. config=self.config,
  1175. add_pooling_layer=False,
  1176. dtype=self.dtype,
  1177. gradient_checkpointing=self.gradient_checkpointing,
  1178. )
  1179. self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
  1180. def __call__(
  1181. self,
  1182. input_ids,
  1183. attention_mask,
  1184. position_ids,
  1185. token_type_ids: Optional[jnp.ndarray] = None,
  1186. head_mask: Optional[jnp.ndarray] = None,
  1187. encoder_hidden_states: Optional[jnp.ndarray] = None,
  1188. encoder_attention_mask: Optional[jnp.ndarray] = None,
  1189. init_cache: bool = False,
  1190. deterministic: bool = True,
  1191. output_attentions: bool = False,
  1192. output_hidden_states: bool = False,
  1193. return_dict: bool = True,
  1194. ):
  1195. # Model
  1196. outputs = self.roberta(
  1197. input_ids,
  1198. attention_mask,
  1199. token_type_ids,
  1200. position_ids,
  1201. head_mask,
  1202. encoder_hidden_states=encoder_hidden_states,
  1203. encoder_attention_mask=encoder_attention_mask,
  1204. init_cache=init_cache,
  1205. deterministic=deterministic,
  1206. output_attentions=output_attentions,
  1207. output_hidden_states=output_hidden_states,
  1208. return_dict=return_dict,
  1209. )
  1210. hidden_states = outputs[0]
  1211. if self.config.tie_word_embeddings:
  1212. shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
  1213. else:
  1214. shared_embedding = None
  1215. # Compute the prediction scores
  1216. logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
  1217. if not return_dict:
  1218. return (logits,) + outputs[1:]
  1219. return FlaxCausalLMOutputWithCrossAttentions(
  1220. logits=logits,
  1221. hidden_states=outputs.hidden_states,
  1222. attentions=outputs.attentions,
  1223. cross_attentions=outputs.cross_attentions,
  1224. )
  1225. @add_start_docstrings(
  1226. """
  1227. Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g. for
  1228. autoregressive tasks.
  1229. """,
  1230. ROBERTA_START_DOCSTRING,
  1231. )
  1232. class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
  1233. module_class = FlaxRobertaForCausalLMModule
  1234. def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
  1235. # initializing the cache
  1236. batch_size, seq_length = input_ids.shape
  1237. past_key_values = self.init_cache(batch_size, max_length)
  1238. # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
  1239. # But since the decoder uses a causal mask, those positions are masked anyway.
  1240. # Thus, we can create a single static attention_mask here, which is more efficient for compilation
  1241. extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
  1242. if attention_mask is not None:
  1243. position_ids = attention_mask.cumsum(axis=-1) - 1
  1244. extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
  1245. else:
  1246. position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
  1247. return {
  1248. "past_key_values": past_key_values,
  1249. "attention_mask": extended_attention_mask,
  1250. "position_ids": position_ids,
  1251. }
  1252. def update_inputs_for_generation(self, model_outputs, model_kwargs):
  1253. model_kwargs["past_key_values"] = model_outputs.past_key_values
  1254. model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
  1255. return model_kwargs
  1256. append_call_sample_docstring(
  1257. FlaxRobertaForCausalLM,
  1258. _CHECKPOINT_FOR_DOC,
  1259. FlaxCausalLMOutputWithCrossAttentions,
  1260. _CONFIG_FOR_DOC,
  1261. )
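# A minimal generation sketch (editor's addition); RoBERTa is not pretrained as a decoder, so the
# checkpoint and the `is_decoder=True` override here are illustrative assumptions:
#
#   from transformers import AutoTokenizer, FlaxRobertaForCausalLM
#
#   tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
#   model = FlaxRobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", is_decoder=True)
#   inputs = tokenizer("Hello, my dog is", return_tensors="np")
#   out = model.generate(inputs["input_ids"], max_length=12)
#   tokenizer.batch_decode(out.sequences, skip_special_tokens=True)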