modeling_squeezebert.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087
  1. # coding=utf-8
  2. # Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch SqueezeBert model."""
  16. import math
  17. from typing import Optional, Tuple, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ...activations import ACT2FN
  22. from ...modeling_outputs import (
  23. BaseModelOutput,
  24. BaseModelOutputWithPooling,
  25. MaskedLMOutput,
  26. MultipleChoiceModelOutput,
  27. QuestionAnsweringModelOutput,
  28. SequenceClassifierOutput,
  29. TokenClassifierOutput,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
  33. from .configuration_squeezebert import SqueezeBertConfig
# Module-level logger from the library's own logging utilities (imported from `...utils`).
logger = logging.get_logger(__name__)
# Checkpoint name and config class name interpolated into the generated code-sample docstrings
# by the `add_code_sample_docstrings(...)` decorators in this module.
_CHECKPOINT_FOR_DOC = "squeezebert/squeezebert-uncased"
_CONFIG_FOR_DOC = "SqueezeBertConfig"
  37. class SqueezeBertEmbeddings(nn.Module):
  38. """Construct the embeddings from word, position and token_type embeddings."""
  39. def __init__(self, config):
  40. super().__init__()
  41. self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
  42. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
  43. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
  44. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  45. # any TensorFlow checkpoint file
  46. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  47. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  48. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  49. self.register_buffer(
  50. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  51. )
  52. def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
  53. if input_ids is not None:
  54. input_shape = input_ids.size()
  55. else:
  56. input_shape = inputs_embeds.size()[:-1]
  57. seq_length = input_shape[1]
  58. if position_ids is None:
  59. position_ids = self.position_ids[:, :seq_length]
  60. if token_type_ids is None:
  61. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  62. if inputs_embeds is None:
  63. inputs_embeds = self.word_embeddings(input_ids)
  64. position_embeddings = self.position_embeddings(position_ids)
  65. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  66. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  67. embeddings = self.LayerNorm(embeddings)
  68. embeddings = self.dropout(embeddings)
  69. return embeddings
  70. class MatMulWrapper(nn.Module):
  71. """
  72. Wrapper for torch.matmul(). This makes flop-counting easier to implement. Note that if you directly call
  73. torch.matmul() in your code, the flop counter will typically ignore the flops of the matmul.
  74. """
  75. def __init__(self):
  76. super().__init__()
  77. def forward(self, mat1, mat2):
  78. """
  79. :param inputs: two torch tensors :return: matmul of these tensors
  80. Here are the typical dimensions found in BERT (the B is optional) mat1.shape: [B, <optional extra dims>, M, K]
  81. mat2.shape: [B, <optional extra dims>, K, N] output shape: [B, <optional extra dims>, M, N]
  82. """
  83. return torch.matmul(mat1, mat2)
  84. class SqueezeBertLayerNorm(nn.LayerNorm):
  85. """
  86. This is a nn.LayerNorm subclass that accepts NCW data layout and performs normalization in the C dimension.
  87. N = batch C = channels W = sequence length
  88. """
  89. def __init__(self, hidden_size, eps=1e-12):
  90. nn.LayerNorm.__init__(self, normalized_shape=hidden_size, eps=eps) # instantiates self.{weight, bias, eps}
  91. def forward(self, x):
  92. x = x.permute(0, 2, 1)
  93. x = nn.LayerNorm.forward(self, x)
  94. return x.permute(0, 2, 1)
  95. class ConvDropoutLayerNorm(nn.Module):
  96. """
  97. ConvDropoutLayerNorm: Conv, Dropout, LayerNorm
  98. """
  99. def __init__(self, cin, cout, groups, dropout_prob):
  100. super().__init__()
  101. self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
  102. self.layernorm = SqueezeBertLayerNorm(cout)
  103. self.dropout = nn.Dropout(dropout_prob)
  104. def forward(self, hidden_states, input_tensor):
  105. x = self.conv1d(hidden_states)
  106. x = self.dropout(x)
  107. x = x + input_tensor
  108. x = self.layernorm(x)
  109. return x
  110. class ConvActivation(nn.Module):
  111. """
  112. ConvActivation: Conv, Activation
  113. """
  114. def __init__(self, cin, cout, groups, act):
  115. super().__init__()
  116. self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
  117. self.act = ACT2FN[act]
  118. def forward(self, x):
  119. output = self.conv1d(x)
  120. return self.act(output)
  121. class SqueezeBertSelfAttention(nn.Module):
  122. def __init__(self, config, cin, q_groups=1, k_groups=1, v_groups=1):
  123. """
  124. config = used for some things; ignored for others (work in progress...) cin = input channels = output channels
  125. groups = number of groups to use in conv1d layers
  126. """
  127. super().__init__()
  128. if cin % config.num_attention_heads != 0:
  129. raise ValueError(
  130. f"cin ({cin}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
  131. )
  132. self.num_attention_heads = config.num_attention_heads
  133. self.attention_head_size = int(cin / config.num_attention_heads)
  134. self.all_head_size = self.num_attention_heads * self.attention_head_size
  135. self.query = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=q_groups)
  136. self.key = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=k_groups)
  137. self.value = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=v_groups)
  138. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  139. self.softmax = nn.Softmax(dim=-1)
  140. self.matmul_qk = MatMulWrapper()
  141. self.matmul_qkv = MatMulWrapper()
  142. def transpose_for_scores(self, x):
  143. """
  144. - input: [N, C, W]
  145. - output: [N, C1, W, C2] where C1 is the head index, and C2 is one head's contents
  146. """
  147. new_x_shape = (x.size()[0], self.num_attention_heads, self.attention_head_size, x.size()[-1]) # [N, C1, C2, W]
  148. x = x.view(*new_x_shape)
  149. return x.permute(0, 1, 3, 2) # [N, C1, C2, W] --> [N, C1, W, C2]
  150. def transpose_key_for_scores(self, x):
  151. """
  152. - input: [N, C, W]
  153. - output: [N, C1, C2, W] where C1 is the head index, and C2 is one head's contents
  154. """
  155. new_x_shape = (x.size()[0], self.num_attention_heads, self.attention_head_size, x.size()[-1]) # [N, C1, C2, W]
  156. x = x.view(*new_x_shape)
  157. # no `permute` needed
  158. return x
  159. def transpose_output(self, x):
  160. """
  161. - input: [N, C1, W, C2]
  162. - output: [N, C, W]
  163. """
  164. x = x.permute(0, 1, 3, 2).contiguous() # [N, C1, C2, W]
  165. new_x_shape = (x.size()[0], self.all_head_size, x.size()[3]) # [N, C, W]
  166. x = x.view(*new_x_shape)
  167. return x
  168. def forward(self, hidden_states, attention_mask, output_attentions):
  169. """
  170. expects hidden_states in [N, C, W] data layout.
  171. The attention_mask data layout is [N, W], and it does not need to be transposed.
  172. """
  173. mixed_query_layer = self.query(hidden_states)
  174. mixed_key_layer = self.key(hidden_states)
  175. mixed_value_layer = self.value(hidden_states)
  176. query_layer = self.transpose_for_scores(mixed_query_layer)
  177. key_layer = self.transpose_key_for_scores(mixed_key_layer)
  178. value_layer = self.transpose_for_scores(mixed_value_layer)
  179. # Take the dot product between "query" and "key" to get the raw attention scores.
  180. attention_score = self.matmul_qk(query_layer, key_layer)
  181. attention_score = attention_score / math.sqrt(self.attention_head_size)
  182. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  183. attention_score = attention_score + attention_mask
  184. # Normalize the attention scores to probabilities.
  185. attention_probs = self.softmax(attention_score)
  186. # This is actually dropping out entire tokens to attend to, which might
  187. # seem a bit unusual, but is taken from the original Transformer paper.
  188. attention_probs = self.dropout(attention_probs)
  189. context_layer = self.matmul_qkv(attention_probs, value_layer)
  190. context_layer = self.transpose_output(context_layer)
  191. result = {"context_layer": context_layer}
  192. if output_attentions:
  193. result["attention_score"] = attention_score
  194. return result
  195. class SqueezeBertModule(nn.Module):
  196. def __init__(self, config):
  197. """
  198. - hidden_size = input chans = output chans for Q, K, V (they are all the same ... for now) = output chans for
  199. the module
  200. - intermediate_size = output chans for intermediate layer
  201. - groups = number of groups for all layers in the BertModule. (eventually we could change the interface to
  202. allow different groups for different layers)
  203. """
  204. super().__init__()
  205. c0 = config.hidden_size
  206. c1 = config.hidden_size
  207. c2 = config.intermediate_size
  208. c3 = config.hidden_size
  209. self.attention = SqueezeBertSelfAttention(
  210. config=config, cin=c0, q_groups=config.q_groups, k_groups=config.k_groups, v_groups=config.v_groups
  211. )
  212. self.post_attention = ConvDropoutLayerNorm(
  213. cin=c0, cout=c1, groups=config.post_attention_groups, dropout_prob=config.hidden_dropout_prob
  214. )
  215. self.intermediate = ConvActivation(cin=c1, cout=c2, groups=config.intermediate_groups, act=config.hidden_act)
  216. self.output = ConvDropoutLayerNorm(
  217. cin=c2, cout=c3, groups=config.output_groups, dropout_prob=config.hidden_dropout_prob
  218. )
  219. def forward(self, hidden_states, attention_mask, output_attentions):
  220. att = self.attention(hidden_states, attention_mask, output_attentions)
  221. attention_output = att["context_layer"]
  222. post_attention_output = self.post_attention(attention_output, hidden_states)
  223. intermediate_output = self.intermediate(post_attention_output)
  224. layer_output = self.output(intermediate_output, post_attention_output)
  225. output_dict = {"feature_map": layer_output}
  226. if output_attentions:
  227. output_dict["attention_score"] = att["attention_score"]
  228. return output_dict
  229. class SqueezeBertEncoder(nn.Module):
  230. def __init__(self, config):
  231. super().__init__()
  232. assert config.embedding_size == config.hidden_size, (
  233. "If you want embedding_size != intermediate hidden_size, "
  234. "please insert a Conv1d layer to adjust the number of channels "
  235. "before the first SqueezeBertModule."
  236. )
  237. self.layers = nn.ModuleList(SqueezeBertModule(config) for _ in range(config.num_hidden_layers))
  238. def forward(
  239. self,
  240. hidden_states,
  241. attention_mask=None,
  242. head_mask=None,
  243. output_attentions=False,
  244. output_hidden_states=False,
  245. return_dict=True,
  246. ):
  247. if head_mask is None:
  248. head_mask_is_all_none = True
  249. elif head_mask.count(None) == len(head_mask):
  250. head_mask_is_all_none = True
  251. else:
  252. head_mask_is_all_none = False
  253. assert head_mask_is_all_none is True, "head_mask is not yet supported in the SqueezeBert implementation."
  254. # [batch_size, sequence_length, hidden_size] --> [batch_size, hidden_size, sequence_length]
  255. hidden_states = hidden_states.permute(0, 2, 1)
  256. all_hidden_states = () if output_hidden_states else None
  257. all_attentions = () if output_attentions else None
  258. for layer in self.layers:
  259. if output_hidden_states:
  260. hidden_states = hidden_states.permute(0, 2, 1)
  261. all_hidden_states += (hidden_states,)
  262. hidden_states = hidden_states.permute(0, 2, 1)
  263. layer_output = layer.forward(hidden_states, attention_mask, output_attentions)
  264. hidden_states = layer_output["feature_map"]
  265. if output_attentions:
  266. all_attentions += (layer_output["attention_score"],)
  267. # [batch_size, hidden_size, sequence_length] --> [batch_size, sequence_length, hidden_size]
  268. hidden_states = hidden_states.permute(0, 2, 1)
  269. if output_hidden_states:
  270. all_hidden_states += (hidden_states,)
  271. if not return_dict:
  272. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  273. return BaseModelOutput(
  274. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
  275. )
  276. class SqueezeBertPooler(nn.Module):
  277. def __init__(self, config):
  278. super().__init__()
  279. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  280. self.activation = nn.Tanh()
  281. def forward(self, hidden_states):
  282. # We "pool" the model by simply taking the hidden state corresponding
  283. # to the first token.
  284. first_token_tensor = hidden_states[:, 0]
  285. pooled_output = self.dense(first_token_tensor)
  286. pooled_output = self.activation(pooled_output)
  287. return pooled_output
  288. class SqueezeBertPredictionHeadTransform(nn.Module):
  289. def __init__(self, config):
  290. super().__init__()
  291. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  292. if isinstance(config.hidden_act, str):
  293. self.transform_act_fn = ACT2FN[config.hidden_act]
  294. else:
  295. self.transform_act_fn = config.hidden_act
  296. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  297. def forward(self, hidden_states):
  298. hidden_states = self.dense(hidden_states)
  299. hidden_states = self.transform_act_fn(hidden_states)
  300. hidden_states = self.LayerNorm(hidden_states)
  301. return hidden_states
  302. class SqueezeBertLMPredictionHead(nn.Module):
  303. def __init__(self, config):
  304. super().__init__()
  305. self.transform = SqueezeBertPredictionHeadTransform(config)
  306. # The output weights are the same as the input embeddings, but there is
  307. # an output-only bias for each token.
  308. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  309. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  310. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  311. self.decoder.bias = self.bias
  312. def _tie_weights(self) -> None:
  313. self.decoder.bias = self.bias
  314. def forward(self, hidden_states):
  315. hidden_states = self.transform(hidden_states)
  316. hidden_states = self.decoder(hidden_states)
  317. return hidden_states
  318. class SqueezeBertOnlyMLMHead(nn.Module):
  319. def __init__(self, config):
  320. super().__init__()
  321. self.predictions = SqueezeBertLMPredictionHead(config)
  322. def forward(self, sequence_output):
  323. prediction_scores = self.predictions(sequence_output)
  324. return prediction_scores
  325. class SqueezeBertPreTrainedModel(PreTrainedModel):
  326. """
  327. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  328. models.
  329. """
  330. config_class = SqueezeBertConfig
  331. base_model_prefix = "transformer"
  332. def _init_weights(self, module):
  333. """Initialize the weights"""
  334. if isinstance(module, (nn.Linear, nn.Conv1d)):
  335. # Slightly different from the TF version which uses truncated_normal for initialization
  336. # cf https://github.com/pytorch/pytorch/pull/5617
  337. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  338. if module.bias is not None:
  339. module.bias.data.zero_()
  340. elif isinstance(module, nn.Embedding):
  341. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  342. if module.padding_idx is not None:
  343. module.weight.data[module.padding_idx].zero_()
  344. elif isinstance(module, SqueezeBertLayerNorm):
  345. module.bias.data.zero_()
  346. module.weight.data.fill_(1.0)
  347. SQUEEZEBERT_START_DOCSTRING = r"""
  348. The SqueezeBERT model was proposed in [SqueezeBERT: What can computer vision teach NLP about efficient neural
  349. networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W.
  350. Keutzer
  351. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  352. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  353. etc.)
  354. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  355. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  356. and behavior.
  357. For best results finetuning SqueezeBERT on text classification tasks, it is recommended to use the
  358. *squeezebert/squeezebert-mnli-headless* checkpoint as a starting point.
  359. Parameters:
  360. config ([`SqueezeBertConfig`]): Model configuration class with all the parameters of the model.
  361. Initializing with a config file does not load the weights associated with the model, only the
  362. configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  363. Hierarchy:
  364. ```
  365. Internal class hierarchy:
  366. SqueezeBertModel
  367. SqueezeBertEncoder
  368. SqueezeBertModule
  369. SqueezeBertSelfAttention
  370. ConvActivation
  371. ConvDropoutLayerNorm
  372. ```
  373. Data layouts:
  374. ```
  375. Input data is in [batch, sequence_length, hidden_size] format.
  376. Data inside the encoder is in [batch, hidden_size, sequence_length] format. But, if `output_hidden_states == True`, the data from inside the encoder is returned in [batch, sequence_length, hidden_size] format.
  377. The final output of the encoder is in [batch, sequence_length, hidden_size] format.
  378. ```
  379. """
  380. SQUEEZEBERT_INPUTS_DOCSTRING = r"""
  381. Args:
  382. input_ids (`torch.LongTensor` of shape `({0})`):
  383. Indices of input sequence tokens in the vocabulary.
  384. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  385. [`PreTrainedTokenizer.__call__`] for details.
  386. [What are input IDs?](../glossary#input-ids)
  387. attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
  388. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  389. - 1 for tokens that are **not masked**,
  390. - 0 for tokens that are **masked**.
  391. [What are attention masks?](../glossary#attention-mask)
  392. token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  393. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  394. 1]`:
  395. - 0 corresponds to a *sentence A* token,
  396. - 1 corresponds to a *sentence B* token.
  397. [What are token type IDs?](../glossary#token-type-ids)
  398. position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  399. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  400. config.max_position_embeddings - 1]`.
  401. [What are position IDs?](../glossary#position-ids)
  402. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  403. Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
  404. - 1 indicates the head is **not masked**,
  405. - 0 indicates the head is **masked**.
  406. inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
  407. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  408. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  409. model's internal embedding lookup matrix.
  410. output_attentions (`bool`, *optional*):
  411. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  412. tensors for more detail.
  413. output_hidden_states (`bool`, *optional*):
  414. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  415. more detail.
  416. return_dict (`bool`, *optional*):
  417. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  418. """
  419. @add_start_docstrings(
  420. "The bare SqueezeBERT Model transformer outputting raw hidden-states without any specific head on top.",
  421. SQUEEZEBERT_START_DOCSTRING,
  422. )
  423. class SqueezeBertModel(SqueezeBertPreTrainedModel):
  424. def __init__(self, config):
  425. super().__init__(config)
  426. self.embeddings = SqueezeBertEmbeddings(config)
  427. self.encoder = SqueezeBertEncoder(config)
  428. self.pooler = SqueezeBertPooler(config)
  429. # Initialize weights and apply final processing
  430. self.post_init()
  431. def get_input_embeddings(self):
  432. return self.embeddings.word_embeddings
  433. def set_input_embeddings(self, new_embeddings):
  434. self.embeddings.word_embeddings = new_embeddings
  435. def _prune_heads(self, heads_to_prune):
  436. """
  437. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  438. class PreTrainedModel
  439. """
  440. for layer, heads in heads_to_prune.items():
  441. self.encoder.layer[layer].attention.prune_heads(heads)
  442. @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  443. @add_code_sample_docstrings(
  444. checkpoint=_CHECKPOINT_FOR_DOC,
  445. output_type=BaseModelOutputWithPooling,
  446. config_class=_CONFIG_FOR_DOC,
  447. )
  448. def forward(
  449. self,
  450. input_ids: Optional[torch.Tensor] = None,
  451. attention_mask: Optional[torch.Tensor] = None,
  452. token_type_ids: Optional[torch.Tensor] = None,
  453. position_ids: Optional[torch.Tensor] = None,
  454. head_mask: Optional[torch.Tensor] = None,
  455. inputs_embeds: Optional[torch.FloatTensor] = None,
  456. output_attentions: Optional[bool] = None,
  457. output_hidden_states: Optional[bool] = None,
  458. return_dict: Optional[bool] = None,
  459. ) -> Union[Tuple, BaseModelOutputWithPooling]:
  460. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  461. output_hidden_states = (
  462. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  463. )
  464. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  465. if input_ids is not None and inputs_embeds is not None:
  466. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  467. elif input_ids is not None:
  468. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  469. input_shape = input_ids.size()
  470. elif inputs_embeds is not None:
  471. input_shape = inputs_embeds.size()[:-1]
  472. else:
  473. raise ValueError("You have to specify either input_ids or inputs_embeds")
  474. device = input_ids.device if input_ids is not None else inputs_embeds.device
  475. if attention_mask is None:
  476. attention_mask = torch.ones(input_shape, device=device)
  477. if token_type_ids is None:
  478. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  479. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
  480. # Prepare head mask if needed
  481. # 1.0 in head_mask indicate we keep the head
  482. # attention_probs has shape bsz x n_heads x N x N
  483. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  484. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  485. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  486. embedding_output = self.embeddings(
  487. input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  488. )
  489. encoder_outputs = self.encoder(
  490. hidden_states=embedding_output,
  491. attention_mask=extended_attention_mask,
  492. head_mask=head_mask,
  493. output_attentions=output_attentions,
  494. output_hidden_states=output_hidden_states,
  495. return_dict=return_dict,
  496. )
  497. sequence_output = encoder_outputs[0]
  498. pooled_output = self.pooler(sequence_output)
  499. if not return_dict:
  500. return (sequence_output, pooled_output) + encoder_outputs[1:]
  501. return BaseModelOutputWithPooling(
  502. last_hidden_state=sequence_output,
  503. pooler_output=pooled_output,
  504. hidden_states=encoder_outputs.hidden_states,
  505. attentions=encoder_outputs.attentions,
  506. )
  507. @add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top.""", SQUEEZEBERT_START_DOCSTRING)
  508. class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
  509. _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
  510. def __init__(self, config):
  511. super().__init__(config)
  512. self.transformer = SqueezeBertModel(config)
  513. self.cls = SqueezeBertOnlyMLMHead(config)
  514. # Initialize weights and apply final processing
  515. self.post_init()
  516. def get_output_embeddings(self):
  517. return self.cls.predictions.decoder
  518. def set_output_embeddings(self, new_embeddings):
  519. self.cls.predictions.decoder = new_embeddings
  520. self.cls.predictions.bias = new_embeddings.bias
  @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
  @add_code_sample_docstrings(
      checkpoint=_CHECKPOINT_FOR_DOC,
      output_type=MaskedLMOutput,
      config_class=_CONFIG_FOR_DOC,
  )
  def forward(
      self,
      input_ids: Optional[torch.Tensor] = None,
      attention_mask: Optional[torch.Tensor] = None,
      token_type_ids: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.Tensor] = None,
      head_mask: Optional[torch.Tensor] = None,
      inputs_embeds: Optional[torch.Tensor] = None,
      labels: Optional[torch.Tensor] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
  ) -> Union[Tuple, MaskedLMOutput]:
      r"""
      labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
          Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
          config.vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
          (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
      """
      # Fall back to the config's default output format when the caller does not choose one.
      return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      outputs = self.transformer(
          input_ids,
          attention_mask=attention_mask,
          token_type_ids=token_type_ids,
          position_ids=position_ids,
          head_mask=head_mask,
          inputs_embeds=inputs_embeds,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
      )

      # outputs[0] is the sequence output; the LM head (`self.cls`) maps it to per-token vocabulary scores.
      sequence_output = outputs[0]
      prediction_scores = self.cls(sequence_output)

      masked_lm_loss = None
      if labels is not None:
          # CrossEntropyLoss ignores targets equal to -100 by default, so those positions add no loss.
          loss_fct = CrossEntropyLoss()
          # `reshape` rather than `view`: it also accepts non-contiguous label tensors (e.g. `labels[:, 1:]`),
          # for which `view(-1)` raises a RuntimeError. For contiguous labels it still returns a view.
          masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.reshape(-1))

      if not return_dict:
          # Tuple output: (loss?, prediction_scores, *extra transformer outputs). outputs[2:] skips the
          # sequence output (index 0) and the pooled output (index 1).
          output = (prediction_scores,) + outputs[2:]
          return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

      return MaskedLMOutput(
          loss=masked_lm_loss,
          logits=prediction_scores,
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
      )
  @add_start_docstrings(
      """
      SqueezeBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
      pooled output) e.g. for GLUE tasks.
      """,
      SQUEEZEBERT_START_DOCSTRING,
  )
  class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          # Submodules are registered in the same order as before (transformer, dropout, classifier)
          # so parameter / state-dict ordering is unchanged.
          self.num_labels = config.num_labels
          self.config = config
          self.transformer = SqueezeBertModel(config)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.classifier = nn.Linear(config.hidden_size, self.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_DOC,
          output_type=SequenceClassifierOutput,
          config_class=_CONFIG_FOR_DOC,
      )
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple, SequenceClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Targets for the sequence-level loss. For classification, values lie in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss);
              if `config.num_labels > 1` a classification loss is computed (Cross-Entropy, or BCE-with-logits for
              non-integer labels when `config.problem_type` is unset).
          """
          return_dict = self.config.use_return_dict if return_dict is None else return_dict

          transformer_outputs = self.transformer(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Index 1 of the transformer output is the pooled representation fed to the classifier.
          logits = self.classifier(self.dropout(transformer_outputs[1]))

          loss = None
          if labels is not None:
              problem_type = self.config.problem_type
              if problem_type is None:
                  # Infer the task from num_labels and the label dtype, then cache it on the config
                  # so later calls reuse the same choice.
                  if self.num_labels == 1:
                      problem_type = "regression"
                  elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                      problem_type = "single_label_classification"
                  else:
                      problem_type = "multi_label_classification"
                  self.config.problem_type = problem_type

              if problem_type == "regression":
                  mse = MSELoss()
                  # With a single output, drop singleton dims so logits and labels line up elementwise.
                  loss = mse(logits.squeeze(), labels.squeeze()) if self.num_labels == 1 else mse(logits, labels)
              elif problem_type == "single_label_classification":
                  loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
              elif problem_type == "multi_label_classification":
                  loss = BCEWithLogitsLoss()(logits, labels)

          if return_dict:
              return SequenceClassifierOutput(
                  loss=loss,
                  logits=logits,
                  hidden_states=transformer_outputs.hidden_states,
                  attentions=transformer_outputs.attentions,
              )

          # Tuple output: (loss?, logits, *extra transformer outputs after the sequence and pooled outputs).
          head = (logits,) + transformer_outputs[2:]
          return head if loss is None else (loss,) + head
  @add_start_docstrings(
      """
      SqueezeBERT Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
      a softmax) e.g. for RocStories/SWAG tasks.
      """,
      SQUEEZEBERT_START_DOCSTRING,
  )
  class SqueezeBertForMultipleChoice(SqueezeBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          self.transformer = SqueezeBertModel(config)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          # One score per (example, choice) pair; scores are regrouped per example in forward().
          self.classifier = nn.Linear(config.hidden_size, 1)

          # Initialize weights and apply final processing
          self.post_init()

      @add_start_docstrings_to_model_forward(
          SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
      )
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_DOC,
          output_type=MultipleChoiceModelOutput,
          config_class=_CONFIG_FOR_DOC,
      )
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple, MultipleChoiceModelOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
              num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
              *input_ids* above)
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          # The choice axis is dimension 1 of whichever input was provided.
          if input_ids is not None:
              num_choices = input_ids.shape[1]
          elif inputs_embeds is not None:
              num_choices = inputs_embeds.shape[1]
          else:
              # Previously this case crashed with an opaque AttributeError on `None.shape`.
              raise ValueError("You have to specify either input_ids or inputs_embeds")

          # Fold the choice axis into the batch axis so the encoder sees ordinary 2-D token inputs
          # (3-D for embeddings). `reshape` is used instead of `view` so non-contiguous inputs are accepted.
          input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
          attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None
          token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
          position_ids = position_ids.reshape(-1, position_ids.size(-1)) if position_ids is not None else None
          inputs_embeds = (
              inputs_embeds.reshape(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
              if inputs_embeds is not None
              else None
          )

          outputs = self.transformer(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Index 1 is the pooled output; each flattened (example, choice) row gets a single score.
          pooled_output = outputs[1]
          pooled_output = self.dropout(pooled_output)
          logits = self.classifier(pooled_output)
          # Regroup the scores so each row holds all choices of one example.
          reshaped_logits = logits.view(-1, num_choices)

          loss = None
          if labels is not None:
              loss_fct = CrossEntropyLoss()
              loss = loss_fct(reshaped_logits, labels)

          if not return_dict:
              output = (reshaped_logits,) + outputs[2:]
              return ((loss,) + output) if loss is not None else output

          return MultipleChoiceModelOutput(
              loss=loss,
              logits=reshaped_logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  @add_start_docstrings(
      """
      SqueezeBERT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
      for Named-Entity-Recognition (NER) tasks.
      """,
      SQUEEZEBERT_START_DOCSTRING,
  )
  class SqueezeBertForTokenClassification(SqueezeBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels

          self.transformer = SqueezeBertModel(config)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          # Per-token projection from hidden size to label scores.
          self.classifier = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_DOC,
          output_type=TokenClassifierOutput,
          config_class=_CONFIG_FOR_DOC,
      )
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple, TokenClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
              Tokens with label `-100` are ignored by the loss.
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          outputs = self.transformer(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # outputs[0] is the sequence output; every token position is classified independently.
          sequence_output = outputs[0]
          sequence_output = self.dropout(sequence_output)
          logits = self.classifier(sequence_output)

          loss = None
          if labels is not None:
              # Default CrossEntropyLoss skips targets equal to -100 (e.g. padding / sub-word continuations).
              loss_fct = CrossEntropyLoss()
              # `reshape` rather than `view` so non-contiguous label tensors (e.g. slices) are accepted;
              # `view(-1)` raises a RuntimeError on them.
              loss = loss_fct(logits.view(-1, self.num_labels), labels.reshape(-1))

          if not return_dict:
              # Tuple output: (loss?, logits, *extra transformer outputs after the sequence and pooled outputs).
              output = (logits,) + outputs[2:]
              return ((loss,) + output) if loss is not None else output

          return TokenClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )
  @add_start_docstrings(
      """
      SqueezeBERT Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
      linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
      """,
      SQUEEZEBERT_START_DOCSTRING,
  )
  class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels

          self.transformer = SqueezeBertModel(config)
          # Projects each token's hidden state to `num_labels` scores; forward() splits them into
          # one start score and one end score per token.
          self.qa_outputs = nn.Linear(config.hidden_size, self.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @add_start_docstrings_to_model_forward(SQUEEZEBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
      @add_code_sample_docstrings(
          checkpoint=_CHECKPOINT_FOR_DOC,
          output_type=QuestionAnsweringModelOutput,
          config_class=_CONFIG_FOR_DOC,
      )
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          head_mask: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          start_positions: Optional[torch.Tensor] = None,
          end_positions: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[Tuple, QuestionAnsweringModelOutput]:
          r"""
          start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Index of the first token of the labelled answer span, used for the span-start loss. Values are
              clamped to `[0, sequence_length]`; a position at or beyond *sequence_length* is ignored by the loss.
          end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Index of the last token of the labelled answer span, used for the span-end loss. Values are
              clamped to `[0, sequence_length]`; a position at or beyond *sequence_length* is ignored by the loss.
          """
          return_dict = self.config.use_return_dict if return_dict is None else return_dict

          outputs = self.transformer(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          span_logits = self.qa_outputs(outputs[0])
          # Split the last axis into single-column pieces and drop that axis from each.
          start_logits, end_logits = [
              piece.squeeze(-1).contiguous() for piece in span_logits.split(1, dim=-1)
          ]

          total_loss = None
          if start_positions is not None and end_positions is not None:
              # Targets may carry an extra trailing dimension (e.g. when gathered across devices); drop it.
              if start_positions.dim() > 1:
                  start_positions = start_positions.squeeze(-1)
              if end_positions.dim() > 1:
                  end_positions = end_positions.squeeze(-1)

              # Positions past the end are clamped onto `seq_len`, which is also the ignore_index,
              # so such targets contribute nothing to the loss.
              seq_len = start_logits.size(1)
              start_positions = start_positions.clamp(0, seq_len)
              end_positions = end_positions.clamp(0, seq_len)

              criterion = CrossEntropyLoss(ignore_index=seq_len)
              total_loss = (criterion(start_logits, start_positions) + criterion(end_logits, end_positions)) / 2

          if return_dict:
              return QuestionAnsweringModelOutput(
                  loss=total_loss,
                  start_logits=start_logits,
                  end_logits=end_logits,
                  hidden_states=outputs.hidden_states,
                  attentions=outputs.attentions,
              )

          # Tuple output: (loss?, start_logits, end_logits, *extra transformer outputs).
          head = (start_logits, end_logits) + outputs[2:]
          return head if total_loss is None else (total_loss,) + head