modeling_blip_text.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953
  1. # coding=utf-8
  2. # Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the BSD-3-clause license (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # https://opensource.org/licenses/BSD-3-Clause
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. from typing import List, Optional, Tuple, Union
  17. import torch
  18. import torch.utils.checkpoint
  19. from torch import Tensor, device, nn
  20. from torch.nn import CrossEntropyLoss
  21. from ...activations import ACT2FN
  22. from ...generation import GenerationMixin
  23. from ...modeling_outputs import (
  24. BaseModelOutputWithPastAndCrossAttentions,
  25. BaseModelOutputWithPoolingAndCrossAttentions,
  26. CausalLMOutputWithCrossAttentions,
  27. )
  28. from ...modeling_utils import (
  29. PreTrainedModel,
  30. apply_chunking_to_forward,
  31. find_pruneable_heads_and_indices,
  32. prune_linear_layer,
  33. )
  34. from ...utils import logging
  35. from .configuration_blip import BlipTextConfig
  36. logger = logging.get_logger(__name__)
  37. # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L52
  38. class BlipTextEmbeddings(nn.Module):
  39. """Construct the embeddings from word and position embeddings."""
  40. def __init__(self, config):
  41. super().__init__()
  42. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  43. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  44. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  45. # any TensorFlow checkpoint file
  46. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  47. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  48. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  49. self.register_buffer(
  50. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  51. )
  52. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  53. self.config = config
  54. def forward(
  55. self,
  56. input_ids: Optional[torch.LongTensor] = None,
  57. position_ids: Optional[torch.LongTensor] = None,
  58. inputs_embeds: Optional[torch.FloatTensor] = None,
  59. past_key_values_length: int = 0,
  60. ) -> torch.Tensor:
  61. if input_ids is not None:
  62. input_shape = input_ids.size()
  63. else:
  64. input_shape = inputs_embeds.size()[:-1]
  65. seq_length = input_shape[1]
  66. if position_ids is None:
  67. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  68. if inputs_embeds is None:
  69. input_ids = input_ids.to(self.word_embeddings.weight.device)
  70. inputs_embeds = self.word_embeddings(input_ids)
  71. embeddings = inputs_embeds
  72. if self.position_embedding_type == "absolute":
  73. position_embeddings = self.position_embeddings(position_ids)
  74. embeddings += position_embeddings
  75. embeddings = self.LayerNorm(embeddings)
  76. embeddings = self.dropout(embeddings)
  77. return embeddings
  78. # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L97
  79. class BlipTextSelfAttention(nn.Module):
  80. def __init__(self, config, is_cross_attention):
  81. super().__init__()
  82. self.config = config
  83. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  84. raise ValueError(
  85. "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
  86. % (config.hidden_size, config.num_attention_heads)
  87. )
  88. self.num_attention_heads = config.num_attention_heads
  89. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  90. self.all_head_size = self.num_attention_heads * self.attention_head_size
  91. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  92. if is_cross_attention:
  93. self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  94. self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
  95. else:
  96. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  97. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  98. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  99. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  100. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  101. self.max_position_embeddings = config.max_position_embeddings
  102. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  103. def save_attn_gradients(self, attn_gradients):
  104. self.attn_gradients = attn_gradients
  105. def get_attn_gradients(self):
  106. return self.attn_gradients
  107. def save_attention_map(self, attention_map):
  108. self.attention_map = attention_map
  109. def get_attention_map(self):
  110. return self.attention_map
  111. def transpose_for_scores(self, x):
  112. new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
  113. x = x.view(*new_x_shape)
  114. return x.permute(0, 2, 1, 3)
  115. def forward(
  116. self,
  117. hidden_states: torch.Tensor,
  118. attention_mask: Optional[torch.FloatTensor] = None,
  119. head_mask: Optional[torch.FloatTensor] = None,
  120. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  121. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  122. past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  123. output_attentions: Optional[bool] = False,
  124. ) -> Tuple[torch.Tensor]:
  125. mixed_query_layer = self.query(hidden_states)
  126. # If this is instantiated as a cross-attention module, the keys
  127. # and values come from an encoder; the attention mask needs to be
  128. # such that the encoder's padding tokens are not attended to.
  129. is_cross_attention = encoder_hidden_states is not None
  130. if is_cross_attention:
  131. key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
  132. value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
  133. attention_mask = encoder_attention_mask
  134. elif past_key_value is not None:
  135. key_layer = self.transpose_for_scores(self.key(hidden_states))
  136. value_layer = self.transpose_for_scores(self.value(hidden_states))
  137. key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
  138. value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
  139. else:
  140. key_layer = self.transpose_for_scores(self.key(hidden_states))
  141. value_layer = self.transpose_for_scores(self.value(hidden_states))
  142. query_layer = self.transpose_for_scores(mixed_query_layer)
  143. past_key_value = (key_layer, value_layer)
  144. # Take the dot product between "query" and "key" to get the raw attention scores.
  145. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  146. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  147. seq_length = hidden_states.size()[1]
  148. position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  149. position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  150. distance = position_ids_l - position_ids_r
  151. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  152. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  153. if self.position_embedding_type == "relative_key":
  154. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  155. attention_scores = attention_scores + relative_position_scores
  156. elif self.position_embedding_type == "relative_key_query":
  157. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  158. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  159. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  160. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  161. if attention_mask is not None:
  162. # Apply the attention mask is (precomputed for all layers in BlipTextModel forward() function)
  163. attention_scores = attention_scores + attention_mask.to(attention_scores.device)
  164. # Normalize the attention scores to probabilities.
  165. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  166. # This is actually dropping out entire tokens to attend to, which might
  167. # seem a bit unusual, but is taken from the original Transformer paper.
  168. attention_probs_dropped = self.dropout(attention_probs)
  169. # Mask heads if we want to
  170. if head_mask is not None:
  171. attention_probs_dropped = attention_probs_dropped * head_mask
  172. context_layer = torch.matmul(attention_probs_dropped, value_layer)
  173. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  174. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  175. context_layer = context_layer.view(*new_context_layer_shape)
  176. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  177. outputs = outputs + (past_key_value,)
  178. return outputs
  179. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert -> BlipText
  180. class BlipTextSelfOutput(nn.Module):
  181. def __init__(self, config):
  182. super().__init__()
  183. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  184. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  185. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  186. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  187. hidden_states = self.dense(hidden_states)
  188. hidden_states = self.dropout(hidden_states)
  189. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  190. return hidden_states
  191. # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L242
  192. class BlipTextAttention(nn.Module):
  193. def __init__(self, config, is_cross_attention=False):
  194. super().__init__()
  195. self.self = BlipTextSelfAttention(config, is_cross_attention)
  196. self.output = BlipTextSelfOutput(config)
  197. self.pruned_heads = set()
  198. def prune_heads(self, heads):
  199. if len(heads) == 0:
  200. return
  201. heads, index = find_pruneable_heads_and_indices(
  202. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  203. )
  204. # Prune linear layers
  205. self.self.query = prune_linear_layer(self.self.query, index)
  206. self.self.key = prune_linear_layer(self.self.key, index)
  207. self.self.value = prune_linear_layer(self.self.value, index)
  208. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  209. # Update hyper params and store pruned heads
  210. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  211. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  212. self.pruned_heads = self.pruned_heads.union(heads)
  213. def forward(
  214. self,
  215. hidden_states: torch.Tensor,
  216. attention_mask: Optional[torch.FloatTensor] = None,
  217. head_mask: Optional[torch.FloatTensor] = None,
  218. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  219. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  220. past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  221. output_attentions: Optional[bool] = False,
  222. ) -> Tuple[torch.Tensor]:
  223. self_outputs = self.self(
  224. hidden_states,
  225. attention_mask,
  226. head_mask,
  227. encoder_hidden_states,
  228. encoder_attention_mask,
  229. past_key_value,
  230. output_attentions,
  231. )
  232. attention_output = self.output(self_outputs[0], hidden_states)
  233. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  234. return outputs
  235. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert -> BlipText
  236. class BlipTextIntermediate(nn.Module):
  237. def __init__(self, config):
  238. super().__init__()
  239. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  240. if isinstance(config.hidden_act, str):
  241. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  242. else:
  243. self.intermediate_act_fn = config.hidden_act
  244. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  245. hidden_states = self.dense(hidden_states)
  246. hidden_states = self.intermediate_act_fn(hidden_states)
  247. return hidden_states
  248. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert -> BlipText
  249. class BlipTextOutput(nn.Module):
  250. def __init__(self, config):
  251. super().__init__()
  252. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  253. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  254. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  255. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  256. hidden_states = self.dense(hidden_states)
  257. hidden_states = self.dropout(hidden_states)
  258. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  259. return hidden_states
  260. class BlipTextLayer(nn.Module):
  261. def __init__(self, config, layer_num):
  262. super().__init__()
  263. self.config = config
  264. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  265. self.seq_len_dim = 1
  266. self.attention = BlipTextAttention(config)
  267. self.layer_num = layer_num
  268. if self.config.is_decoder:
  269. self.crossattention = BlipTextAttention(config, is_cross_attention=self.config.is_decoder)
  270. self.intermediate = BlipTextIntermediate(config)
  271. self.output = BlipTextOutput(config)
  272. def forward(
  273. self,
  274. hidden_states: torch.Tensor,
  275. attention_mask: Optional[torch.FloatTensor] = None,
  276. head_mask: Optional[torch.FloatTensor] = None,
  277. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  278. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  279. past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  280. output_attentions: Optional[bool] = False,
  281. ) -> Tuple[torch.Tensor]:
  282. # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
  283. self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
  284. self_attention_outputs = self.attention(
  285. hidden_states,
  286. attention_mask,
  287. head_mask,
  288. output_attentions=output_attentions,
  289. past_key_value=self_attn_past_key_value,
  290. )
  291. attention_output = self_attention_outputs[0]
  292. outputs = self_attention_outputs[1:-1]
  293. present_key_value = self_attention_outputs[-1]
  294. if encoder_hidden_states is not None:
  295. cross_attention_outputs = self.crossattention(
  296. attention_output,
  297. attention_mask,
  298. head_mask,
  299. encoder_hidden_states,
  300. encoder_attention_mask,
  301. output_attentions=output_attentions,
  302. )
  303. attention_output = cross_attention_outputs[0]
  304. outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
  305. layer_output = apply_chunking_to_forward(
  306. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  307. )
  308. outputs = (layer_output,) + outputs
  309. outputs = outputs + (present_key_value,)
  310. return outputs
  311. def feed_forward_chunk(self, attention_output):
  312. intermediate_output = self.intermediate(attention_output)
  313. layer_output = self.output(intermediate_output, attention_output)
  314. return layer_output
  315. # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L386
  316. class BlipTextEncoder(nn.Module):
  317. def __init__(self, config):
  318. super().__init__()
  319. self.config = config
  320. self.layer = nn.ModuleList([BlipTextLayer(config, i) for i in range(config.num_hidden_layers)])
  321. self.gradient_checkpointing = False
  322. def forward(
  323. self,
  324. hidden_states: torch.Tensor,
  325. attention_mask: Optional[torch.FloatTensor] = None,
  326. head_mask: Optional[torch.FloatTensor] = None,
  327. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  328. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  329. past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  330. use_cache: Optional[bool] = None,
  331. output_attentions: Optional[bool] = False,
  332. output_hidden_states: Optional[bool] = False,
  333. return_dict: Optional[bool] = True,
  334. ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
  335. if self.gradient_checkpointing and self.training:
  336. if use_cache:
  337. logger.warning(
  338. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  339. )
  340. use_cache = False
  341. all_hidden_states = () if output_hidden_states else None
  342. all_self_attentions = () if output_attentions else None
  343. all_cross_attentions = () if output_attentions and self.config.is_decoder else None
  344. next_decoder_cache = () if use_cache else None
  345. for i in range(self.config.num_hidden_layers):
  346. layer_module = self.layer[i]
  347. if output_hidden_states:
  348. all_hidden_states = all_hidden_states + (hidden_states,)
  349. layer_head_mask = head_mask[i] if head_mask is not None else None
  350. past_key_value = past_key_values[i] if past_key_values is not None else None
  351. if self.gradient_checkpointing and self.training:
  352. layer_outputs = self._gradient_checkpointing_func(
  353. layer_module.__call__,
  354. hidden_states,
  355. attention_mask,
  356. layer_head_mask,
  357. encoder_hidden_states,
  358. encoder_attention_mask,
  359. past_key_value,
  360. output_attentions,
  361. )
  362. else:
  363. layer_outputs = layer_module(
  364. hidden_states,
  365. attention_mask,
  366. layer_head_mask,
  367. encoder_hidden_states,
  368. encoder_attention_mask,
  369. past_key_value,
  370. output_attentions,
  371. )
  372. hidden_states = layer_outputs[0]
  373. if use_cache:
  374. next_decoder_cache += (layer_outputs[-1],)
  375. if output_attentions:
  376. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  377. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  378. if output_hidden_states:
  379. all_hidden_states = all_hidden_states + (hidden_states,)
  380. if not return_dict:
  381. return tuple(
  382. v
  383. for v in [
  384. hidden_states,
  385. next_decoder_cache,
  386. all_hidden_states,
  387. all_self_attentions,
  388. all_cross_attentions,
  389. ]
  390. if v is not None
  391. )
  392. return BaseModelOutputWithPastAndCrossAttentions(
  393. last_hidden_state=hidden_states,
  394. past_key_values=next_decoder_cache,
  395. hidden_states=all_hidden_states,
  396. attentions=all_self_attentions,
  397. cross_attentions=all_cross_attentions,
  398. )
  399. # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->BlipText
  400. class BlipTextPooler(nn.Module):
  401. def __init__(self, config):
  402. super().__init__()
  403. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  404. self.activation = nn.Tanh()
  405. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  406. # We "pool" the model by simply taking the hidden state corresponding
  407. # to the first token.
  408. first_token_tensor = hidden_states[:, 0]
  409. pooled_output = self.dense(first_token_tensor)
  410. pooled_output = self.activation(pooled_output)
  411. return pooled_output
  412. # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->BlipText
  413. class BlipTextPredictionHeadTransform(nn.Module):
  414. def __init__(self, config):
  415. super().__init__()
  416. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  417. if isinstance(config.hidden_act, str):
  418. self.transform_act_fn = ACT2FN[config.hidden_act]
  419. else:
  420. self.transform_act_fn = config.hidden_act
  421. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  422. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  423. hidden_states = self.dense(hidden_states)
  424. hidden_states = self.transform_act_fn(hidden_states)
  425. hidden_states = self.LayerNorm(hidden_states)
  426. return hidden_states
  427. # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->BlipText
  428. class BlipTextLMPredictionHead(nn.Module):
  429. def __init__(self, config):
  430. super().__init__()
  431. self.transform = BlipTextPredictionHeadTransform(config)
  432. # The output weights are the same as the input embeddings, but there is
  433. # an output-only bias for each token.
  434. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  435. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  436. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  437. self.decoder.bias = self.bias
  438. def _tie_weights(self):
  439. self.decoder.bias = self.bias
  440. def forward(self, hidden_states):
  441. hidden_states = self.transform(hidden_states)
  442. hidden_states = self.decoder(hidden_states)
  443. return hidden_states
  444. # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->BlipText
  445. class BlipTextOnlyMLMHead(nn.Module):
  446. def __init__(self, config):
  447. super().__init__()
  448. self.predictions = BlipTextLMPredictionHead(config)
  449. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  450. prediction_scores = self.predictions(sequence_output)
  451. return prediction_scores
  452. # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L548
  453. class BlipTextPreTrainedModel(PreTrainedModel):
  454. """
  455. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  456. models.
  457. """
  458. config_class = BlipTextConfig
  459. base_model_prefix = "bert"
  460. _no_split_modules = []
  461. def _init_weights(self, module):
  462. """Initialize the weights"""
  463. if isinstance(module, (nn.Linear, nn.Embedding)):
  464. # Slightly different from the TF version which uses truncated_normal for initialization
  465. # cf https://github.com/pytorch/pytorch/pull/5617
  466. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  467. elif isinstance(module, nn.LayerNorm):
  468. module.bias.data.zero_()
  469. module.weight.data.fill_(1.0)
  470. if isinstance(module, nn.Linear) and module.bias is not None:
  471. module.bias.data.zero_()
  472. # Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571
  473. class BlipTextModel(BlipTextPreTrainedModel):
  474. """
  475. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  476. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  477. all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  478. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. argument and `is_decoder` set to `True`; an
  479. `encoder_hidden_states` is then expected as an input to the forward pass.
  480. """
  481. def __init__(self, config, add_pooling_layer=True):
  482. super().__init__(config)
  483. self.config = config
  484. self.embeddings = BlipTextEmbeddings(config)
  485. self.encoder = BlipTextEncoder(config)
  486. self.pooler = BlipTextPooler(config) if add_pooling_layer else None
  487. self.post_init()
  488. def get_input_embeddings(self):
  489. return self.embeddings.word_embeddings
  490. def set_input_embeddings(self, value):
  491. self.embeddings.word_embeddings = value
  492. # Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads
  493. def _prune_heads(self, heads_to_prune):
  494. """
  495. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  496. class PreTrainedModel
  497. """
  498. for layer, heads in heads_to_prune.items():
  499. self.encoder.layer[layer].attention.prune_heads(heads)
  500. def get_extended_attention_mask(
  501. self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool
  502. ) -> Tensor:
  503. """
  504. Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
  505. Arguments:
  506. attention_mask (`torch.Tensor`):
  507. Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
  508. input_shape (`Tuple[int]`):
  509. The shape of the input to the model.
  510. device (`torch.device`):
  511. The device of the input to the model.
  512. Returns:
  513. `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
  514. """
  515. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  516. # ourselves in which case we just need to make it broadcastable to all heads.
  517. if attention_mask.dim() == 3:
  518. extended_attention_mask = attention_mask[:, None, :, :]
  519. elif attention_mask.dim() == 2:
  520. # Provided a padding mask of dimensions [batch_size, seq_length]
  521. # - if the model is a decoder, apply a causal mask in addition to the padding mask
  522. # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
  523. if is_decoder:
  524. batch_size, seq_length = input_shape
  525. seq_ids = torch.arange(seq_length, device=device)
  526. causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
  527. # in case past_key_values are used we need to add a prefix ones mask to the causal mask
  528. # causal and attention masks must have same type with pytorch version < 1.3
  529. causal_mask = causal_mask.to(attention_mask.dtype)
  530. if causal_mask.shape[1] < attention_mask.shape[1]:
  531. prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
  532. causal_mask = torch.cat(
  533. [
  534. torch.ones(
  535. (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
  536. ),
  537. causal_mask,
  538. ],
  539. axis=-1,
  540. )
  541. extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
  542. else:
  543. extended_attention_mask = attention_mask[:, None, None, :]
  544. else:
  545. raise ValueError(
  546. "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
  547. input_shape, attention_mask.shape
  548. )
  549. )
  550. # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
  551. # masked positions, this operation will create a tensor which is 0.0 for
  552. # positions we want to attend and -10000.0 for masked positions.
  553. # Since we are adding it to the raw scores before the softmax, this is
  554. # effectively the same as removing these entirely.
  555. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
  556. extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
  557. return extended_attention_mask
  def forward(
      self,
      input_ids: Optional[torch.Tensor] = None,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.Tensor] = None,
      head_mask: Optional[torch.Tensor] = None,
      inputs_embeds: Optional[torch.Tensor] = None,
      encoder_embeds: Optional[torch.Tensor] = None,
      encoder_hidden_states: Optional[torch.Tensor] = None,
      encoder_attention_mask: Optional[torch.Tensor] = None,
      past_key_values: Optional[List[torch.FloatTensor]] = None,
      use_cache: Optional[bool] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      is_decoder: Optional[bool] = False,
  ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
      r"""
      Build the self-/cross-attention masks, embed the inputs (unless `encoder_embeds` is given) and run the
      encoder stack, optionally followed by the pooler.

      encoder_embeds (`torch.FloatTensor`, *optional*):
          Already-embedded inputs. When given, they are passed to the encoder as-is and `self.embeddings` is not
          called. They are also used to infer `(batch_size, sequence_length)` and the device when neither
          `input_ids` nor `inputs_embeds` is provided (all but the last dimension are read as the input shape).
      encoder_hidden_states (`torch.FloatTensor`, *optional*):
          Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
          the model is configured as a decoder. A list of tensors is also accepted; the first element is then used
          to infer the encoder batch size and sequence length.
      encoder_attention_mask (`torch.FloatTensor`, *optional*):
          Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
          the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
          - 1 for tokens that are **not masked**,
          - 0 for tokens that are **masked**.
          A list of masks is also accepted; each one is inverted separately. When omitted while
          `encoder_hidden_states` is given, an all-ones mask is used.
      past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
          Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
          If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
          don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
          `decoder_input_ids` of shape `(batch_size, sequence_length)`.
      use_cache (`bool`, *optional*):
          If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
          `past_key_values`). Ignored (forced to `False`) unless `is_decoder` is truthy.
      is_decoder (`bool`, *optional*, defaults to `False`):
          Forwarded to `get_extended_attention_mask` when building the self-attention mask, and gates `use_cache`:
          when falsy, caching is disabled; when truthy, `use_cache` falls back to `config.use_cache` if unset.

      Raises:
          ValueError: if both `input_ids` and `inputs_embeds` are given, or if none of `input_ids`,
              `inputs_embeds` and `encoder_embeds` is given.
      """
      output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
          output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
      return_dict = return_dict if return_dict is not None else self.config.use_return_dict

      if is_decoder:
          use_cache = use_cache if use_cache is not None else self.config.use_cache
      else:
          use_cache = False

      # Resolve the input shape and device from whichever input representation was supplied.
      if input_ids is not None and inputs_embeds is not None:
          raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
      elif input_ids is not None:
          self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
          input_shape = input_ids.size()
          batch_size, seq_length = input_shape
          device = input_ids.device
      elif inputs_embeds is not None:
          input_shape = inputs_embeds.size()[:-1]
          batch_size, seq_length = input_shape
          device = inputs_embeds.device
      elif encoder_embeds is not None:
          input_shape = encoder_embeds.size()[:-1]
          batch_size, seq_length = input_shape
          device = encoder_embeds.device
      else:
          raise ValueError("You have to specify either input_ids or inputs_embeds or encoder_embeds")

      # Number of positions already held in the cache (dim 2 of the first cached tensor).
      past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

      if attention_mask is None:
          # Default mask spans cached + new positions; allocate it directly on the target device
          # instead of creating it on the host and copying it over.
          attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

      # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
      # ourselves in which case we just need to make it broadcastable to all heads.
      extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
          attention_mask, input_shape, device, is_decoder
      )

      # If a 2D or 3D attention mask is provided for the cross-attention
      # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
      if encoder_hidden_states is not None:
          if isinstance(encoder_hidden_states, list):
              encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
          else:
              encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
          encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

          if isinstance(encoder_attention_mask, list):
              encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
          elif encoder_attention_mask is None:
              encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
              encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
          else:
              encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
      else:
          encoder_extended_attention_mask = None

      # Prepare head mask if needed
      # 1.0 in head_mask indicate we keep the head
      # attention_probs has shape bsz x n_heads x N x N
      # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
      # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
      head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

      # Pre-computed embeddings take precedence over the embedding layer.
      if encoder_embeds is None:
          embedding_output = self.embeddings(
              input_ids=input_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              past_key_values_length=past_key_values_length,
          )
      else:
          embedding_output = encoder_embeds

      encoder_outputs = self.encoder(
          embedding_output,
          attention_mask=extended_attention_mask,
          head_mask=head_mask,
          encoder_hidden_states=encoder_hidden_states,
          encoder_attention_mask=encoder_extended_attention_mask,
          past_key_values=past_key_values,
          use_cache=use_cache,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
      )
      sequence_output = encoder_outputs[0]
      pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

      if not return_dict:
          # Tuple layout: (sequence_output, pooled_output, *remaining encoder outputs).
          return (sequence_output, pooled_output) + encoder_outputs[1:]

      return BaseModelOutputWithPoolingAndCrossAttentions(
          last_hidden_state=sequence_output,
          pooler_output=pooled_output,
          past_key_values=encoder_outputs.past_key_values,
          hidden_states=encoder_outputs.hidden_states,
          attentions=encoder_outputs.attentions,
          cross_attentions=encoder_outputs.cross_attentions,
      )
  684. # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811
  685. class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
  def __init__(self, config):
      """Build the BLIP text backbone and the language-modeling head from `config`."""
      super().__init__(config)
      # Text backbone, constructed with `add_pooling_layer=False`
      # (presumably because no pooled output is needed for language modeling -- confirm).
      self.bert = BlipTextModel(config, add_pooling_layer=False)
      # Head applied to the backbone's sequence output in `forward` to obtain prediction scores.
      self.cls = BlipTextOnlyMLMHead(config)
      # Passed to `CrossEntropyLoss(label_smoothing=...)` when `forward` computes a loss.
      self.label_smoothing = config.label_smoothing
  def get_output_embeddings(self):
      """Return the `decoder` attribute of the prediction head (`self.cls.predictions`)."""
      prediction_head = self.cls.predictions
      return prediction_head.decoder
  def set_output_embeddings(self, new_embeddings):
      """Install `new_embeddings` as the prediction head's decoder and point the head's bias at its bias."""
      prediction_head = self.cls.predictions
      prediction_head.decoder = new_embeddings
      # Keep the head's standalone `bias` attribute referring to the new decoder's bias.
      prediction_head.bias = new_embeddings.bias
  def forward(
      self,
      input_ids: Optional[torch.Tensor] = None,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.Tensor] = None,
      head_mask: Optional[torch.Tensor] = None,
      inputs_embeds: Optional[torch.Tensor] = None,
      encoder_hidden_states: Optional[torch.Tensor] = None,
      encoder_attention_mask: Optional[torch.Tensor] = None,
      labels: Optional[torch.Tensor] = None,
      past_key_values: Optional[List[torch.Tensor]] = None,
      use_cache: Optional[bool] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      return_logits: Optional[bool] = False,
      is_decoder: Optional[bool] = True,
      reduction: Optional[str] = "mean",
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
      r"""
      Run the text model, project its sequence output to prediction scores and, when `labels` are given,
      compute the next-token language-modeling loss.

      encoder_hidden_states (`torch.FloatTensor`, *optional*):
          Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
          the model is configured as a decoder.
      encoder_attention_mask (`torch.FloatTensor`, *optional*):
          Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
          the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
          - 1 for tokens that are **not masked**,
          - 0 for tokens that are **masked**.
      labels (`torch.LongTensor`, *optional*):
          Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
          `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
          ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
          Supplying labels forces `use_cache` to `False`.
      past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
          Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
          If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
          don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
          `decoder_input_ids` of shape `(batch_size, sequence_length)`.
      use_cache (`bool`, *optional*):
          If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
          `past_key_values`).
      return_logits (`bool`, *optional*, defaults to `False`):
          If truthy, return only the prediction scores with the last position dropped, as a plain tensor, and skip
          the loss computation and output packaging.
      is_decoder (`bool`, *optional*, defaults to `True`):
          Forwarded unchanged to the underlying text model.
      reduction (`str`, *optional*, defaults to `"mean"`):
          Reduction passed to `CrossEntropyLoss`. With `"none"`, the per-token losses are reshaped to one row per
          batch element and summed along the sequence, yielding one loss value per sequence.
      """
      if return_dict is None:
          return_dict = self.config.use_return_dict

      # No key/value caching when a loss is going to be computed.
      if labels is not None:
          use_cache = False

      outputs = self.bert(
          input_ids,
          attention_mask=attention_mask,
          position_ids=position_ids,
          head_mask=head_mask,
          inputs_embeds=inputs_embeds,
          encoder_hidden_states=encoder_hidden_states,
          encoder_attention_mask=encoder_attention_mask,
          past_key_values=past_key_values,
          use_cache=use_cache,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
          is_decoder=is_decoder,
      )

      # First element of the text-model output is the sequence of hidden states.
      prediction_scores = self.cls(outputs[0])

      if return_logits:
          return prediction_scores[:, :-1, :].contiguous()

      lm_loss = None
      if labels is not None:
          # Next-token prediction: position t is scored against the label at t + 1,
          # so drop the last score and the first label.
          next_token_scores = prediction_scores[:, :-1, :].contiguous()
          target_ids = labels[:, 1:].contiguous().to(next_token_scores.device)

          criterion = CrossEntropyLoss(reduction=reduction, label_smoothing=self.label_smoothing)
          flat_scores = next_token_scores.view(-1, self.config.vocab_size)
          lm_loss = criterion(flat_scores, target_ids.view(-1))

          if reduction == "none":
              # Collapse the unreduced per-token losses into one summed loss per sequence.
              num_sequences = prediction_scores.size(0)
              lm_loss = lm_loss.view(num_sequences, -1).sum(1)

      if return_dict:
          return CausalLMOutputWithCrossAttentions(
              loss=lm_loss,
              logits=prediction_scores,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
              cross_attentions=outputs.cross_attentions,
          )

      # Tuple output: skip the first two entries of the text-model output and lead with the scores
      # (preceded by the loss when there is one).
      output = (prediction_scores,) + outputs[2:]
      if lm_loss is None:
          return output
      return (lm_loss,) + output
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
      """
      Assemble the keyword arguments for one generation step.

      Creates an all-ones attention mask when none is supplied, trims `input_ids` to the tokens not yet covered
      by `past_key_values`, and always sets `is_decoder=True` in the returned dict. `encoder_hidden_states` and
      `encoder_attention_mask` are taken from `model_kwargs` (`None` when absent).
      """
      # Overwrite -- hardcoded key return (`is_decoder=True`)

      # When used as the decoder of an encoder-decoder model, the decoder attention mask is created on the fly.
      if attention_mask is None:
          attention_mask = input_ids.new_ones(input_ids.shape)

      if past_key_values is not None:
          cached_length = past_key_values[0][0].shape[2]
          current_length = input_ids.shape[1]
          # Some generation methods already pass only the last input ID; in that case (input no longer
          # than the cache) fall back to the old behavior and keep just the final ID.
          num_to_drop = cached_length if current_length > cached_length else current_length - 1
          input_ids = input_ids[:, num_to_drop:]

      return {
          "input_ids": input_ids,
          "attention_mask": attention_mask,
          "past_key_values": past_key_values,
          "encoder_hidden_states": model_kwargs.get("encoder_hidden_states"),
          "encoder_attention_mask": model_kwargs.get("encoder_attention_mask"),
          "is_decoder": True,
      }
  803. def _reorder_cache(self, past_key_values, beam_idx):
  804. reordered_past = ()
  805. for layer_past in past_key_values:
  806. reordered_past += (
  807. tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
  808. )
  809. return reordered_past