modeling_flax_electra.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601
  1. # coding=utf-8
  2. # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Callable, Optional, Tuple
  16. import flax
  17. import flax.linen as nn
  18. import jax
  19. import jax.numpy as jnp
  20. import numpy as np
  21. from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
  22. from flax.linen import combine_masks, make_causal_mask
  23. from flax.linen import partitioning as nn_partitioning
  24. from flax.linen.attention import dot_product_attention_weights
  25. from flax.traverse_util import flatten_dict, unflatten_dict
  26. from jax import lax
  27. from ...modeling_flax_outputs import (
  28. FlaxBaseModelOutput,
  29. FlaxBaseModelOutputWithPastAndCrossAttentions,
  30. FlaxCausalLMOutputWithCrossAttentions,
  31. FlaxMaskedLMOutput,
  32. FlaxMultipleChoiceModelOutput,
  33. FlaxQuestionAnsweringModelOutput,
  34. FlaxSequenceClassifierOutput,
  35. FlaxTokenClassifierOutput,
  36. )
  37. from ...modeling_flax_utils import (
  38. ACT2FN,
  39. FlaxPreTrainedModel,
  40. append_call_sample_docstring,
  41. append_replace_return_docstrings,
  42. overwrite_call_docstring,
  43. )
  44. from ...utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, logging
  45. from .configuration_electra import ElectraConfig
logger = logging.get_logger(__name__)

# Default checkpoint / config names injected into the auto-generated call docstrings below.
_CHECKPOINT_FOR_DOC = "google/electra-small-discriminator"
_CONFIG_FOR_DOC = "ElectraConfig"

# Shorthand for Flax's rematerialization transform, used for gradient checkpointing.
remat = nn_partitioning.remat
@flax.struct.dataclass
class FlaxElectraForPreTrainingOutput(ModelOutput):
    """
    Output type of [`ElectraForPreTraining`].

    Args:
        logits (`jnp.ndarray` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # Fields default to None so the dataclass can be built incrementally by the model code.
    logits: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
  70. ELECTRA_START_DOCSTRING = r"""
  71. This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
  72. library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
  73. This model is also a Flax Linen
  74. [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
  75. regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
  76. Finally, this model supports inherent JAX features such as:
  77. - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
  78. - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
  79. - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
  80. - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
  81. Parameters:
  82. config ([`ElectraConfig`]): Model configuration class with all the parameters of the model.
  83. Initializing with a config file does not load the weights associated with the model, only the
  84. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  85. """
  86. ELECTRA_INPUTS_DOCSTRING = r"""
  87. Args:
  88. input_ids (`numpy.ndarray` of shape `({0})`):
  89. Indices of input sequence tokens in the vocabulary.
  90. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  91. [`PreTrainedTokenizer.__call__`] for details.
  92. [What are input IDs?](../glossary#input-ids)
  93. attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
  94. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  95. - 1 for tokens that are **not masked**,
  96. - 0 for tokens that are **masked**.
  97. [What are attention masks?](../glossary#attention-mask)
  98. token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
  99. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  100. 1]`:
  101. - 0 corresponds to a *sentence A* token,
  102. - 1 corresponds to a *sentence B* token.
  103. [What are token type IDs?](../glossary#token-type-ids)
  104. position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
  105. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  106. config.max_position_embeddings - 1]`.
  107. head_mask (`numpy.ndarray` of shape `({0})`, `optional):
  108. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  109. - 1 indicates the head is **not masked**,
  110. - 0 indicates the head is **masked**.
  111. return_dict (`bool`, *optional*):
  112. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  113. """
  114. class FlaxElectraEmbeddings(nn.Module):
  115. """Construct the embeddings from word, position and token_type embeddings."""
  116. config: ElectraConfig
  117. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  118. def setup(self):
  119. self.word_embeddings = nn.Embed(
  120. self.config.vocab_size,
  121. self.config.embedding_size,
  122. embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
  123. )
  124. self.position_embeddings = nn.Embed(
  125. self.config.max_position_embeddings,
  126. self.config.embedding_size,
  127. embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
  128. )
  129. self.token_type_embeddings = nn.Embed(
  130. self.config.type_vocab_size,
  131. self.config.embedding_size,
  132. embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
  133. )
  134. self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
  135. self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
  136. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings.__call__
  137. def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
  138. # Embed
  139. inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
  140. position_embeds = self.position_embeddings(position_ids.astype("i4"))
  141. token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
  142. # Sum all embeddings
  143. hidden_states = inputs_embeds + token_type_embeddings + position_embeds
  144. # Layer Norm
  145. hidden_states = self.LayerNorm(hidden_states)
  146. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  147. return hidden_states
  148. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Electra
  149. class FlaxElectraSelfAttention(nn.Module):
  150. config: ElectraConfig
  151. causal: bool = False
  152. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  153. def setup(self):
  154. self.head_dim = self.config.hidden_size // self.config.num_attention_heads
  155. if self.config.hidden_size % self.config.num_attention_heads != 0:
  156. raise ValueError(
  157. "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
  158. " : {self.config.num_attention_heads}"
  159. )
  160. self.query = nn.Dense(
  161. self.config.hidden_size,
  162. dtype=self.dtype,
  163. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  164. )
  165. self.key = nn.Dense(
  166. self.config.hidden_size,
  167. dtype=self.dtype,
  168. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  169. )
  170. self.value = nn.Dense(
  171. self.config.hidden_size,
  172. dtype=self.dtype,
  173. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  174. )
  175. if self.causal:
  176. self.causal_mask = make_causal_mask(
  177. jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
  178. )
  179. def _split_heads(self, hidden_states):
  180. return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
  181. def _merge_heads(self, hidden_states):
  182. return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
  183. @nn.compact
  184. # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
  185. def _concatenate_to_cache(self, key, value, query, attention_mask):
  186. """
  187. This function takes projected key, value states from a single input token and concatenates the states to cached
  188. states from previous steps. This function is slighly adapted from the official Flax repository:
  189. https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
  190. """
  191. # detect if we're initializing by absence of existing cache data.
  192. is_initialized = self.has_variable("cache", "cached_key")
  193. cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
  194. cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
  195. cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
  196. if is_initialized:
  197. *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
  198. # update key, value caches with our new 1d spatial slices
  199. cur_index = cache_index.value
  200. indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
  201. key = lax.dynamic_update_slice(cached_key.value, key, indices)
  202. value = lax.dynamic_update_slice(cached_value.value, value, indices)
  203. cached_key.value = key
  204. cached_value.value = value
  205. num_updated_cache_vectors = query.shape[1]
  206. cache_index.value = cache_index.value + num_updated_cache_vectors
  207. # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
  208. pad_mask = jnp.broadcast_to(
  209. jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
  210. tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
  211. )
  212. attention_mask = combine_masks(pad_mask, attention_mask)
  213. return key, value, attention_mask
  214. def __call__(
  215. self,
  216. hidden_states,
  217. attention_mask,
  218. layer_head_mask,
  219. key_value_states: Optional[jnp.ndarray] = None,
  220. init_cache: bool = False,
  221. deterministic=True,
  222. output_attentions: bool = False,
  223. ):
  224. # if key_value_states are provided this layer is used as a cross-attention layer
  225. # for the decoder
  226. is_cross_attention = key_value_states is not None
  227. batch_size = hidden_states.shape[0]
  228. # get query proj
  229. query_states = self.query(hidden_states)
  230. # get key, value proj
  231. if is_cross_attention:
  232. # cross_attentions
  233. key_states = self.key(key_value_states)
  234. value_states = self.value(key_value_states)
  235. else:
  236. # self_attention
  237. key_states = self.key(hidden_states)
  238. value_states = self.value(hidden_states)
  239. query_states = self._split_heads(query_states)
  240. key_states = self._split_heads(key_states)
  241. value_states = self._split_heads(value_states)
  242. # handle cache prepare causal attention mask
  243. if self.causal:
  244. query_length, key_length = query_states.shape[1], key_states.shape[1]
  245. if self.has_variable("cache", "cached_key"):
  246. mask_shift = self.variables["cache"]["cache_index"]
  247. max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
  248. causal_mask = lax.dynamic_slice(
  249. self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
  250. )
  251. else:
  252. causal_mask = self.causal_mask[:, :, :query_length, :key_length]
  253. causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
  254. # combine masks if needed
  255. if attention_mask is not None and self.causal:
  256. attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
  257. attention_mask = combine_masks(attention_mask, causal_mask)
  258. elif self.causal:
  259. attention_mask = causal_mask
  260. elif attention_mask is not None:
  261. attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
  262. # During fast autoregressive decoding, we feed one position at a time,
  263. # and cache the keys and values step by step.
  264. if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
  265. key_states, value_states, attention_mask = self._concatenate_to_cache(
  266. key_states, value_states, query_states, attention_mask
  267. )
  268. # Convert the boolean attention mask to an attention bias.
  269. if attention_mask is not None:
  270. # attention mask in the form of attention bias
  271. attention_bias = lax.select(
  272. attention_mask > 0,
  273. jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
  274. jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
  275. )
  276. else:
  277. attention_bias = None
  278. dropout_rng = None
  279. if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
  280. dropout_rng = self.make_rng("dropout")
  281. attn_weights = dot_product_attention_weights(
  282. query_states,
  283. key_states,
  284. bias=attention_bias,
  285. dropout_rng=dropout_rng,
  286. dropout_rate=self.config.attention_probs_dropout_prob,
  287. broadcast_dropout=True,
  288. deterministic=deterministic,
  289. dtype=self.dtype,
  290. precision=None,
  291. )
  292. # Mask heads if we want to
  293. if layer_head_mask is not None:
  294. attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
  295. attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
  296. attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
  297. outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
  298. return outputs
  299. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Electra
  300. class FlaxElectraSelfOutput(nn.Module):
  301. config: ElectraConfig
  302. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  303. def setup(self):
  304. self.dense = nn.Dense(
  305. self.config.hidden_size,
  306. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  307. dtype=self.dtype,
  308. )
  309. self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
  310. self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
  311. def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
  312. hidden_states = self.dense(hidden_states)
  313. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  314. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  315. return hidden_states
  316. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Electra
  317. class FlaxElectraAttention(nn.Module):
  318. config: ElectraConfig
  319. causal: bool = False
  320. dtype: jnp.dtype = jnp.float32
  321. def setup(self):
  322. self.self = FlaxElectraSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
  323. self.output = FlaxElectraSelfOutput(self.config, dtype=self.dtype)
  324. def __call__(
  325. self,
  326. hidden_states,
  327. attention_mask,
  328. layer_head_mask,
  329. key_value_states=None,
  330. init_cache=False,
  331. deterministic=True,
  332. output_attentions: bool = False,
  333. ):
  334. # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
  335. # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
  336. # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
  337. attn_outputs = self.self(
  338. hidden_states,
  339. attention_mask,
  340. layer_head_mask=layer_head_mask,
  341. key_value_states=key_value_states,
  342. init_cache=init_cache,
  343. deterministic=deterministic,
  344. output_attentions=output_attentions,
  345. )
  346. attn_output = attn_outputs[0]
  347. hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
  348. outputs = (hidden_states,)
  349. if output_attentions:
  350. outputs += (attn_outputs[1],)
  351. return outputs
  352. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Electra
  353. class FlaxElectraIntermediate(nn.Module):
  354. config: ElectraConfig
  355. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  356. def setup(self):
  357. self.dense = nn.Dense(
  358. self.config.intermediate_size,
  359. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  360. dtype=self.dtype,
  361. )
  362. self.activation = ACT2FN[self.config.hidden_act]
  363. def __call__(self, hidden_states):
  364. hidden_states = self.dense(hidden_states)
  365. hidden_states = self.activation(hidden_states)
  366. return hidden_states
  367. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Electra
  368. class FlaxElectraOutput(nn.Module):
  369. config: ElectraConfig
  370. dtype: jnp.dtype = jnp.float32 # the dtype of the computation
  371. def setup(self):
  372. self.dense = nn.Dense(
  373. self.config.hidden_size,
  374. kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
  375. dtype=self.dtype,
  376. )
  377. self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
  378. self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
  379. def __call__(self, hidden_states, attention_output, deterministic: bool = True):
  380. hidden_states = self.dense(hidden_states)
  381. hidden_states = self.dropout(hidden_states, deterministic=deterministic)
  382. hidden_states = self.LayerNorm(hidden_states + attention_output)
  383. return hidden_states
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Electra
class FlaxElectraLayer(nn.Module):
    """One transformer layer: self-attention, optional cross-attention (decoder only),
    and the feed-forward block, each with its own residual connection."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        # Self-attention is causal only when the config marks this model as a decoder.
        self.attention = FlaxElectraAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
        self.intermediate = FlaxElectraIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxElectraOutput(self.config, dtype=self.dtype)
        if self.config.add_cross_attention:
            self.crossattention = FlaxElectraAttention(self.config, causal=False, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask,
        layer_head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        # Self Attention
        attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            layer_head_mask=layer_head_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
        )
        attention_output = attention_outputs[0]

        # Cross-Attention Block: only entered when encoder states are supplied (decoder usage).
        if encoder_hidden_states is not None:
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                layer_head_mask=layer_head_mask,
                key_value_states=encoder_hidden_states,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )
            attention_output = cross_attention_outputs[0]

        # Feed-forward block; the residual around it is applied inside self.output.
        hidden_states = self.intermediate(attention_output)
        hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)

        outputs = (hidden_states,)

        # Optionally append self-attention (and cross-attention) probabilities.
        if output_attentions:
            outputs += (attention_outputs[1],)
            if encoder_hidden_states is not None:
                outputs += (cross_attention_outputs[1],)
        return outputs
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Electra
class FlaxElectraLayerCollection(nn.Module):
    """Stack of `config.num_hidden_layers` transformer layers, optionally wrapped in
    Flax rematerialization (gradient checkpointing)."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False  # trade recompute for memory via remat

    def setup(self):
        if self.gradient_checkpointing:
            # static_argnums marks the trailing boolean arguments of FlaxElectraLayer.__call__
            # as static for remat (presumably init_cache/deterministic/output_attentions —
            # these must therefore be passed positionally below).
            FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7))
            self.layers = [
                FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]
        else:
            self.layers = [
                FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_hidden_layers)
            ]

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Accumulators stay None when the corresponding output is not requested.
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None

        # Check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.shape[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
                    f" {head_mask.shape[0]}."
                )

        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # NOTE: arguments are passed positionally so their indices line up with the
            # remat static_argnums used in setup().
            layer_outputs = layer(
                hidden_states,
                attention_mask,
                head_mask[i] if head_mask is not None else None,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                deterministic,
                output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

                if encoder_hidden_states is not None:
                    all_cross_attentions += (layer_outputs[2],)

        # Include the final hidden state in the collected hidden states.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)

        if not return_dict:
            # Tuple output drops the entries that were not requested (None values).
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Electra
class FlaxElectraEncoder(nn.Module):
    """Thin wrapper that delegates to the layer collection; exists to mirror the
    PyTorch module hierarchy (and its parameter naming)."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        self.layer = FlaxElectraLayerCollection(
            self.config,
            dtype=self.dtype,
            gradient_checkpointing=self.gradient_checkpointing,
        )

    def __call__(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # All arguments are forwarded unchanged to the layer collection.
        return self.layer(
            hidden_states,
            attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
  539. class FlaxElectraGeneratorPredictions(nn.Module):
  540. config: ElectraConfig
  541. dtype: jnp.dtype = jnp.float32
  542. def setup(self):
  543. self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
  544. self.dense = nn.Dense(self.config.embedding_size, dtype=self.dtype)
  545. def __call__(self, hidden_states):
  546. hidden_states = self.dense(hidden_states)
  547. hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
  548. hidden_states = self.LayerNorm(hidden_states)
  549. return hidden_states
  550. class FlaxElectraDiscriminatorPredictions(nn.Module):
  551. """Prediction module for the discriminator, made up of two dense layers."""
  552. config: ElectraConfig
  553. dtype: jnp.dtype = jnp.float32
  554. def setup(self):
  555. self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
  556. self.dense_prediction = nn.Dense(1, dtype=self.dtype)
  557. def __call__(self, hidden_states):
  558. hidden_states = self.dense(hidden_states)
  559. hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
  560. hidden_states = self.dense_prediction(hidden_states).squeeze(-1)
  561. return hidden_states
  562. class FlaxElectraPreTrainedModel(FlaxPreTrainedModel):
  563. """
  564. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  565. models.
  566. """
  567. config_class = ElectraConfig
  568. base_model_prefix = "electra"
  569. module_class: nn.Module = None
  570. def __init__(
  571. self,
  572. config: ElectraConfig,
  573. input_shape: Tuple = (1, 1),
  574. seed: int = 0,
  575. dtype: jnp.dtype = jnp.float32,
  576. _do_init: bool = True,
  577. gradient_checkpointing: bool = False,
  578. **kwargs,
  579. ):
  580. module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
  581. super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
  582. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
  583. def enable_gradient_checkpointing(self):
  584. self._module = self.module_class(
  585. config=self.config,
  586. dtype=self.dtype,
  587. gradient_checkpointing=True,
  588. )
  589. # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights
  590. def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
  591. # init input tensors
  592. input_ids = jnp.zeros(input_shape, dtype="i4")
  593. token_type_ids = jnp.zeros_like(input_ids)
  594. position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
  595. attention_mask = jnp.ones_like(input_ids)
  596. head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
  597. params_rng, dropout_rng = jax.random.split(rng)
  598. rngs = {"params": params_rng, "dropout": dropout_rng}
  599. if self.config.add_cross_attention:
  600. encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
  601. encoder_attention_mask = attention_mask
  602. module_init_outputs = self.module.init(
  603. rngs,
  604. input_ids,
  605. attention_mask,
  606. token_type_ids,
  607. position_ids,
  608. head_mask,
  609. encoder_hidden_states,
  610. encoder_attention_mask,
  611. return_dict=False,
  612. )
  613. else:
  614. module_init_outputs = self.module.init(
  615. rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
  616. )
  617. random_params = module_init_outputs["params"]
  618. if params is not None:
  619. random_params = flatten_dict(unfreeze(random_params))
  620. params = flatten_dict(unfreeze(params))
  621. for missing_key in self._missing_keys:
  622. params[missing_key] = random_params[missing_key]
  623. self._missing_keys = set()
  624. return freeze(unflatten_dict(params))
  625. else:
  626. return random_params
  627. # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
  628. def init_cache(self, batch_size, max_length):
  629. r"""
  630. Args:
  631. batch_size (`int`):
  632. batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
  633. max_length (`int`):
  634. maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
  635. cache.
  636. """
  637. # init input variables to retrieve cache
  638. input_ids = jnp.ones((batch_size, max_length), dtype="i4")
  639. attention_mask = jnp.ones_like(input_ids, dtype="i4")
  640. position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
  641. init_variables = self.module.init(
  642. jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
  643. )
  644. return unfreeze(init_variables["cache"])
  645. @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  646. def __call__(
  647. self,
  648. input_ids,
  649. attention_mask=None,
  650. token_type_ids=None,
  651. position_ids=None,
  652. head_mask=None,
  653. encoder_hidden_states=None,
  654. encoder_attention_mask=None,
  655. params: dict = None,
  656. dropout_rng: jax.random.PRNGKey = None,
  657. train: bool = False,
  658. output_attentions: Optional[bool] = None,
  659. output_hidden_states: Optional[bool] = None,
  660. return_dict: Optional[bool] = None,
  661. past_key_values: dict = None,
  662. ):
  663. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  664. output_hidden_states = (
  665. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  666. )
  667. return_dict = return_dict if return_dict is not None else self.config.return_dict
  668. # init input tensors if not passed
  669. if token_type_ids is None:
  670. token_type_ids = jnp.ones_like(input_ids)
  671. if position_ids is None:
  672. position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
  673. if attention_mask is None:
  674. attention_mask = jnp.ones_like(input_ids)
  675. if head_mask is None:
  676. head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
  677. # Handle any PRNG if needed
  678. rngs = {}
  679. if dropout_rng is not None:
  680. rngs["dropout"] = dropout_rng
  681. inputs = {"params": params or self.params}
  682. if self.config.add_cross_attention:
  683. # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
  684. # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
  685. # changed by FlaxElectraAttention module
  686. if past_key_values:
  687. inputs["cache"] = past_key_values
  688. mutable = ["cache"]
  689. else:
  690. mutable = False
  691. outputs = self.module.apply(
  692. inputs,
  693. jnp.array(input_ids, dtype="i4"),
  694. jnp.array(attention_mask, dtype="i4"),
  695. token_type_ids=jnp.array(token_type_ids, dtype="i4"),
  696. position_ids=jnp.array(position_ids, dtype="i4"),
  697. head_mask=jnp.array(head_mask, dtype="i4"),
  698. encoder_hidden_states=encoder_hidden_states,
  699. encoder_attention_mask=encoder_attention_mask,
  700. deterministic=not train,
  701. output_attentions=output_attentions,
  702. output_hidden_states=output_hidden_states,
  703. return_dict=return_dict,
  704. rngs=rngs,
  705. mutable=mutable,
  706. )
  707. # add updated cache to model output
  708. if past_key_values is not None and return_dict:
  709. outputs, past_key_values = outputs
  710. outputs["past_key_values"] = unfreeze(past_key_values["cache"])
  711. return outputs
  712. elif past_key_values is not None and not return_dict:
  713. outputs, past_key_values = outputs
  714. outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
  715. else:
  716. outputs = self.module.apply(
  717. inputs,
  718. jnp.array(input_ids, dtype="i4"),
  719. jnp.array(attention_mask, dtype="i4"),
  720. token_type_ids=jnp.array(token_type_ids, dtype="i4"),
  721. position_ids=jnp.array(position_ids, dtype="i4"),
  722. head_mask=jnp.array(head_mask, dtype="i4"),
  723. deterministic=not train,
  724. output_attentions=output_attentions,
  725. output_hidden_states=output_hidden_states,
  726. return_dict=return_dict,
  727. rngs=rngs,
  728. )
  729. return outputs
class FlaxElectraModule(nn.Module):
    """Bare ELECTRA encoder: embeddings (optionally projected to hidden size) followed by the transformer stack."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation
    gradient_checkpointing: bool = False

    def setup(self):
        self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype)
        # ELECTRA allows embedding_size != hidden_size; project only when they differ.
        if self.config.embedding_size != self.config.hidden_size:
            self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.encoder = FlaxElectraEncoder(
            self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        token_type_ids,
        position_ids,
        head_mask: Optional[np.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        embeddings = self.embeddings(
            input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
        )
        # `embeddings_project` only exists when embedding_size != hidden_size (see setup).
        if hasattr(self, "embeddings_project"):
            embeddings = self.embeddings_project(embeddings)

        return self.encoder(
            embeddings,
            attention_mask,
            head_mask=head_mask,
            deterministic=deterministic,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
@add_start_docstrings(
    "The bare Electra Model transformer outputting raw hidden-states without any specific head on top.",
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraModel(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraModule


# Attach an example usage snippet to the model's __call__ docstring.
append_call_sample_docstring(FlaxElectraModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutput, _CONFIG_FOR_DOC)
class FlaxElectraTiedDense(nn.Module):
    """Dense layer whose kernel is supplied by the caller (used to tie the LM head to the word embeddings);
    only the bias is a learned parameter of this module."""

    embedding_size: int
    dtype: jnp.dtype = jnp.float32
    # NOTE: deliberately unannotated so flax treats it as a plain class attribute, not a dataclass field.
    precision = None
    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros

    def setup(self):
        self.bias = self.param("bias", self.bias_init, (self.embedding_size,))

    def __call__(self, x, kernel):
        x = jnp.asarray(x, self.dtype)
        kernel = jnp.asarray(kernel, self.dtype)
        # Contract the last axis of x with the first axis of kernel (a matmul with no batch dims).
        y = lax.dot_general(
            x,
            kernel,
            (((x.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        bias = jnp.asarray(self.bias, self.dtype)
        return y + bias
class FlaxElectraForMaskedLMModule(nn.Module):
    """ELECTRA generator with a masked-LM head on top."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
        # When embeddings are tied, the LM head reuses the word-embedding matrix as its kernel.
        if self.config.tie_word_embeddings:
            self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
        else:
            self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        prediction_scores = self.generator_predictions(hidden_states)

        if self.config.tie_word_embeddings:
            # Pull the shared embedding matrix out of the submodule's variables and use its
            # transpose as the projection kernel.
            shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
        else:
            prediction_scores = self.generator_lm_head(prediction_scores)

        if not return_dict:
            return (prediction_scores,) + outputs[1:]

        return FlaxMaskedLMOutput(
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings("""Electra Model with a `language modeling` head on top.""", ELECTRA_START_DOCSTRING)
class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForMaskedLMModule


# Attach an example usage snippet to the model's __call__ docstring.
append_call_sample_docstring(FlaxElectraForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC)
class FlaxElectraForPreTrainingModule(nn.Module):
    """ELECTRA discriminator: per-token binary scores for "replaced token detection" pretraining."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        # One logit per token (see FlaxElectraDiscriminatorPredictions).
        logits = self.discriminator_predictions(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxElectraForPreTrainingOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.
    It is recommended to load the discriminator checkpoint into that model.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForPreTrainingModule
# Usage example spliced into FlaxElectraForPreTraining's __call__ docstring below.
FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING = """
    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, FlaxElectraForPreTraining

    >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
    >>> model = FlaxElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
    >>> outputs = model(**inputs)

    >>> prediction_logits = outputs.logits
    ```
"""

# Replace the inherited __call__ docstring and document the return type.
overwrite_call_docstring(
    FlaxElectraForPreTraining,
    ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length") + FLAX_ELECTRA_FOR_PRETRAINING_DOCSTRING,
)
append_replace_return_docstrings(
    FlaxElectraForPreTraining, output_type=FlaxElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
)
class FlaxElectraForTokenClassificationModule(nn.Module):
    """ELECTRA with a per-token classification head (dropout + dense) on top."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # Fall back to the generic hidden dropout when no classifier-specific rate is configured.
        classifier_dropout = (
            self.config.classifier_dropout
            if self.config.classifier_dropout is not None
            else self.config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        logits = self.classifier(hidden_states)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxTokenClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForTokenClassification(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForTokenClassificationModule


# Attach an example usage snippet to the model's __call__ docstring.
append_call_sample_docstring(
    FlaxElectraForTokenClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxTokenClassifierOutput,
    _CONFIG_FOR_DOC,
)
  986. def identity(x, **kwargs):
  987. return x
class FlaxElectraSequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
    """

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Each stage defaults to the no-op `identity` and is replaced only when the
        # corresponding config flag is present and enabled.
        self.summary = identity
        if hasattr(self.config, "summary_use_proj") and self.config.summary_use_proj:
            if (
                hasattr(self.config, "summary_proj_to_labels")
                and self.config.summary_proj_to_labels
                and self.config.num_labels > 0
            ):
                num_classes = self.config.num_labels
            else:
                num_classes = self.config.hidden_size
            self.summary = nn.Dense(num_classes, dtype=self.dtype)

        activation_string = getattr(self.config, "summary_activation", None)
        self.activation = ACT2FN[activation_string] if activation_string else lambda x: x  # noqa F407

        self.first_dropout = identity
        if hasattr(self.config, "summary_first_dropout") and self.config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(self.config.summary_first_dropout)

        self.last_dropout = identity
        if hasattr(self.config, "summary_last_dropout") and self.config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(self.config.summary_last_dropout)

    def __call__(self, hidden_states, cls_index=None, deterministic: bool = True):
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`jnp.ndarray` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`jnp.ndarray` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `jnp.ndarray`: The summary of the sequence hidden states.
        """
        # NOTE: this does the "first" type summary always (takes the first token);
        # `cls_index` is accepted for interface compatibility but not used here.
        output = hidden_states[:, 0]
        output = self.first_dropout(output, deterministic=deterministic)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output, deterministic=deterministic)
        return output
class FlaxElectraForMultipleChoiceModule(nn.Module):
    """ELECTRA with a multiple-choice head: sequence summary + single-logit classifier per choice."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.sequence_summary = FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype)
        self.classifier = nn.Dense(1, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Inputs arrive as (batch, num_choices, seq_len); flatten choices into the batch dim
        # so the encoder sees ordinary (batch * num_choices, seq_len) inputs.
        num_choices = input_ids.shape[1]
        input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
        attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
        token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
        position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None

        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        pooled_output = self.sequence_summary(hidden_states, deterministic=deterministic)
        logits = self.classifier(pooled_output)

        # Un-flatten: one logit per choice.
        reshaped_logits = logits.reshape(-1, num_choices)

        if not return_dict:
            return (reshaped_logits,) + outputs[1:]

        return FlaxMultipleChoiceModelOutput(
            logits=reshaped_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    ELECTRA Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
    softmax) e.g. for RocStories/SWAG tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForMultipleChoiceModule


# adapt docstring slightly for FlaxElectraForMultipleChoice
overwrite_call_docstring(
    FlaxElectraForMultipleChoice, ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
)
append_call_sample_docstring(
    FlaxElectraForMultipleChoice,
    _CHECKPOINT_FOR_DOC,
    FlaxMultipleChoiceModelOutput,
    _CONFIG_FOR_DOC,
)
class FlaxElectraForQuestionAnsweringModule(nn.Module):
    """ELECTRA with a span-prediction head producing start/end logits for extractive QA."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        # num_labels is expected to be 2 here (start + end); the split below relies on it.
        self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]

        # Split the last axis into start and end logits and drop the singleton dim.
        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        if not return_dict:
            return (start_logits, end_logits) + outputs[1:]

        return FlaxQuestionAnsweringModelOutput(
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    ELECTRA Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForQuestionAnswering(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForQuestionAnsweringModule


# Attach an example usage snippet to the model's __call__ docstring.
append_call_sample_docstring(
    FlaxElectraForQuestionAnswering,
    _CHECKPOINT_FOR_DOC,
    FlaxQuestionAnsweringModelOutput,
    _CONFIG_FOR_DOC,
)
  1173. class FlaxElectraClassificationHead(nn.Module):
  1174. """Head for sentence-level classification tasks."""
  1175. config: ElectraConfig
  1176. dtype: jnp.dtype = jnp.float32
  1177. def setup(self):
  1178. self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
  1179. classifier_dropout = (
  1180. self.config.classifier_dropout
  1181. if self.config.classifier_dropout is not None
  1182. else self.config.hidden_dropout_prob
  1183. )
  1184. self.dropout = nn.Dropout(classifier_dropout)
  1185. self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype)
  1186. def __call__(self, hidden_states, deterministic: bool = True):
  1187. x = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
  1188. x = self.dropout(x, deterministic=deterministic)
  1189. x = self.dense(x)
  1190. x = ACT2FN["gelu"](x) # although BERT uses tanh here, it seems Electra authors used gelu
  1191. x = self.dropout(x, deterministic=deterministic)
  1192. x = self.out_proj(x)
  1193. return x
class FlaxElectraForSequenceClassificationModule(nn.Module):
    """ELECTRA with a sentence-level classification head on top of the first token."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        # Model
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        # The head itself pools via the first token (see FlaxElectraClassificationHead).
        logits = self.classifier(hidden_states, deterministic=deterministic)

        if not return_dict:
            return (logits,) + outputs[1:]

        return FlaxSequenceClassifierOutput(
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
@add_start_docstrings(
    """
    Electra Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForSequenceClassificationModule


# Attach an example usage snippet to the model's __call__ docstring.
append_call_sample_docstring(
    FlaxElectraForSequenceClassification,
    _CHECKPOINT_FOR_DOC,
    FlaxSequenceClassifierOutput,
    _CONFIG_FOR_DOC,
)
class FlaxElectraForCausalLMModule(nn.Module):
    """ELECTRA with an autoregressive LM head; supports cross-attention and a decoding cache."""

    config: ElectraConfig
    dtype: jnp.dtype = jnp.float32
    gradient_checkpointing: bool = False

    def setup(self):
        self.electra = FlaxElectraModule(
            config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing
        )
        self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype)
        # When embeddings are tied, the LM head reuses the word-embedding matrix as its kernel.
        if self.config.tie_word_embeddings:
            self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype)
        else:
            self.generator_lm_head = nn.Dense(self.config.vocab_size, dtype=self.dtype)

    def __call__(
        self,
        input_ids,
        attention_mask: Optional[jnp.ndarray] = None,
        token_type_ids: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        head_mask: Optional[jnp.ndarray] = None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            init_cache=init_cache,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        prediction_scores = self.generator_predictions(hidden_states)

        if self.config.tie_word_embeddings:
            # Use the transposed shared word-embedding matrix as the output projection.
            shared_embedding = self.electra.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
            prediction_scores = self.generator_lm_head(prediction_scores, shared_embedding.T)
        else:
            prediction_scores = self.generator_lm_head(prediction_scores)

        if not return_dict:
            return (prediction_scores,) + outputs[1:]

        return FlaxCausalLMOutputWithCrossAttentions(
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
@add_start_docstrings(
    """
    Electra Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
    autoregressive tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForCausalLM with Bert->Electra
class FlaxElectraForCausalLM(FlaxElectraPreTrainedModel):
    module_class = FlaxElectraForCausalLMModule

    def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jax.Array] = None):
        """Build the cache, an extended attention mask, and position ids for incremental decoding."""
        # initializing the cache
        batch_size, seq_length = input_ids.shape

        past_key_values = self.init_cache(batch_size, max_length)
        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
        # But since the decoder uses a causal mask, those positions are masked anyway.
        # Thus, we can create a single static attention_mask here, which is more efficient for compilation
        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
        if attention_mask is not None:
            # Positions are the cumulative count of attended tokens (0-based).
            position_ids = attention_mask.cumsum(axis=-1) - 1
            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
        else:
            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

        return {
            "past_key_values": past_key_values,
            "attention_mask": extended_attention_mask,
            "position_ids": position_ids,
        }

    def update_inputs_for_generation(self, model_outputs, model_kwargs):
        # Advance the cache and the (single) next-token position for the next decode step.
        model_kwargs["past_key_values"] = model_outputs.past_key_values
        model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
        return model_kwargs
# Attach an example usage snippet to the model's __call__ docstring.
append_call_sample_docstring(
    FlaxElectraForCausalLM,
    _CHECKPOINT_FOR_DOC,
    FlaxCausalLMOutputWithCrossAttentions,
    _CONFIG_FOR_DOC,
)