# coding=utf-8
# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch OPT model."""

from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_opt import OPTConfig


if is_flash_attn_2_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
_CONFIG_FOR_DOC = "OPTConfig"

# Base model docstring
_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]

# SequenceClassification docstring
_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "ArthurZ/opt-350m-dummy-sc"
_SEQ_CLASS_EXPECTED_LOSS = 1.71
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_0'"


class OPTLearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(
        self,
        attention_mask: torch.LongTensor,
        past_key_values_length: int = 0,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """`attention_mask` is expected to be [bsz x seqlen]."""
        if position_ids is None:
            position_ids = torch.cumsum(attention_mask, dim=1)
            position_ids = (position_ids * attention_mask - 1).long()
            # cut positions if `past_key_values_length` is > 0
            position_ids = position_ids[:, past_key_values_length:]

        return super().forward(position_ids + self.offset)
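

# The block below is an illustrative sketch added for this edit, not part of the original file. It shows how the
# cumulative-sum trick above turns a left-padded attention mask into position ids (padding positions map to -1,
# i.e. to embedding index 1 once the +2 offset is applied):
#
#     >>> import torch
#     >>> attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # two left-padding tokens
#     >>> position_ids = torch.cumsum(attention_mask, dim=1)
#     >>> (position_ids * attention_mask - 1).long()
#     tensor([[-1, -1,  0,  1,  2]])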


class OPTAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: OPTConfig,
        is_decoder: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.dropout = config.attention_dropout
        self.enable_bias = config.enable_bias
        self.head_dim = self.embed_dim // self.num_heads
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        # position_ids isn't needed in eager attention, but it is accepted here to keep the signature identical to flash attention
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states) * self.scaling
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.view(*proj_shape)
        value_states = value_states.view(*proj_shape)

        src_len = key_states.size(1)
        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        # upcast to fp32 if the weights are in fp16. Please see https://github.com/huggingface/transformers/pull/17437
        if attn_weights.dtype == torch.float16:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16)
        else:
            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # this operation is a bit awkward, but it's required to
            # make sure that attn_weights keeps its gradient.
            # In order to do so, attn_weights have to be reshaped
            # twice and have to be reused in the following
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
        else:
            attn_weights_reshaped = None

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
        # partitioned across GPUs when using tensor-parallelism.
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
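

# Illustrative sketch (added in this edit, not part of the original file): how `_shape` splits the hidden
# dimension into attention heads before the batched matmuls above.
#
#     >>> import torch
#     >>> bsz, seq_len, num_heads, head_dim = 2, 5, 12, 64
#     >>> hidden = torch.randn(bsz, seq_len, num_heads * head_dim)
#     >>> hidden.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).shape
#     torch.Size([2, 12, 5, 64])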


class OptFlashAttention2(OPTAttention):
    """
    OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stay untouched.
    The only required change would be on the forward pass where it needs to correctly call the public API of flash
    attention and deal with padding tokens in case the input contains any of them.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None

        bsz, _, _ = hidden_states.size()

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        query_length = query_states.shape[1]
        tgt_len = key_states.shape[-2]

        # Flash attention requires the input to have the shape
        # batch_size x seq_length x num_heads x head_dim
        query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
        key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
        value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)

        attn_dropout = self.dropout if self.training else 0.0

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently cast to float32. Hence, we need to
        # cast them back to float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to have been silently cast to float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length,
            position_ids=position_ids,
            dropout=attn_dropout,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
        attn_output = self.out_proj(attn_weights_reshaped)

        if not output_attentions:
            attn_weights_reshaped = None

        return attn_output, attn_weights_reshaped, past_key_value


class OPTSdpaAttention(OPTAttention):
    """
    OPT sdpa attention module. This module inherits from `OPTAttention` as the weights of the module stay untouched.
    The only required change would be on the forward pass where it needs to correctly call the public API of sdpa
    attention and deal with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions or layer_head_mask is not None:
            logger.warning_once(
                "OPTModel is using SDPA attention, which currently does not support output_attentions=True."
                ' Falling back to the eager attention implementation. This warning can be removed by setting'
                ' attn_implementation="eager".'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                layer_head_mask=layer_head_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                key_value_states=key_value_states,
            )  # TODO after merge add position_ids=position_ids

        is_cross_attention = key_value_states is not None

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states) * self.scaling
        query_states = self._shape(query_states, -1, bsz)

        # get key, value proj
        if is_cross_attention and past_key_value is not None:
            # reuse k,v, cross_attentions
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            # cross_attentions
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # reuse k, v, self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            # self_attention
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_states, value_states)

        # shape now is (bsz, num_heads, seq_len, head_dim), all are contiguous

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal,
            # this model uses the scaling factor in the query projection for some reason, but not in Q@K^T
            # so we need to scale to remove scaling in SDPA to have similar results with eager.
            # Maybe needs a change in the model to remove scaling in query projection
            scale=1.0,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.out_proj(attn_output)
        return attn_output, None, past_key_value


OPT_ATTENTION_CLASSES = {
    "eager": OPTAttention,
    "flash_attention_2": OptFlashAttention2,
    "sdpa": OPTSdpaAttention,
}
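

# Hedged note (added in this edit, not part of the original file): `OPTDecoderLayer` below picks one of these
# classes through `config._attn_implementation`. In user code the implementation is normally chosen at load time,
# for example:
#
#     >>> from transformers import AutoModelForCausalLM
#     >>> model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", attn_implementation="sdpa")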


class OPTDecoderLayer(nn.Module):
    def __init__(self, config: OPTConfig):
        super().__init__()
        self.embed_dim = config.hidden_size

        self.self_attn = OPT_ATTENTION_CLASSES[config._attn_implementation](config=config, is_decoder=True)

        self.do_layer_norm_before = config.do_layer_norm_before
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
        )
        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            layer_head_mask (`torch.FloatTensor`, *optional*): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states

        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=past_key_value,
            position_ids=position_ids,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        hidden_states_shape = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_states.size(-1))
        residual = hidden_states

        # 125m, 1.7B, ..., 175B applies layer norm BEFORE the feed-forward block
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)

        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)

        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        hidden_states = (residual + hidden_states).view(hidden_states_shape)

        # 350m applies layer norm AFTER the feed-forward block
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


OPT_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`OPTConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare OPT Model outputting raw hidden-states without any specific head on top.",
    OPT_START_DOCSTRING,
)
class OPTPreTrainedModel(PreTrainedModel):
    config_class = OPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["OPTDecoderLayer"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        std = self.config.init_std
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


OPT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
            Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`. For padding use -1.

            [What are position IDs?](../glossary#position-ids)
"""


class OPTDecoder(OPTPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`]

    Args:
        config: OPTConfig
    """

    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine
            )
        else:
            self.final_layer_norm = None

        self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _update_causal_mask(
        self,
        inputs_embeds: torch.Tensor,
        input_shape: Tuple[int, int],
        past_key_values_length: int,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
    ):
        """
        Updates the causal mask for the decoder.
        """
        batch_size, seq_length = input_shape
        mask_seq_length = past_key_values_length + seq_length

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers
            causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
            attention_mask = (
                torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
                if attention_mask is None
                else attention_mask
            )
            return causal_attention_mask, attention_mask

        if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
        elif attention_mask.shape[1] != mask_seq_length:
            raise ValueError(
                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
            )

        if self._use_sdpa and not output_attentions and head_mask is None:
            causal_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )
        else:
            causal_attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        return causal_attention_mask, attention_mask
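
    # Illustrative sketch (added in this edit, not part of the original file): for a sequence of length 3 with no
    # padding and no cache, the 4D mask returned by `_prepare_4d_causal_attention_mask` has shape
    # (batch_size, 1, 3, 3) and is additive, conceptually:
    #
    #     [[   0, -inf, -inf],
    #      [   0,    0, -inf],
    #      [   0,    0,    0]]
    #
    # where "-inf" stands for torch.finfo(dtype).min, so future positions are zeroed out by the softmax.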

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indices of positions of each input sequence token in the position embeddings. Selected in the range
                `[0, config.n_positions - 1]`. For padding use -1.

                [What are position IDs?](../glossary#position-ids)
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        causal_attention_mask, attention_mask = self._update_causal_mask(
            inputs_embeds, input_shape, past_key_values_length, attention_mask, head_mask, output_attentions
        )

        # embed positions
        if position_ids is None:
            position_ids = torch.cumsum(attention_mask, dim=1)
            position_ids = (position_ids * attention_mask - 1).long()
            # cut positions if `past_key_values_length` is > 0
            position_ids = position_ids[:, past_key_values_length:]

        pos_embeds = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)

        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        hidden_states = inputs_embeds + pos_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask has a correct number of layers specified if desired
        for attn_mask, mask_name in zip([head_mask], ["head_mask"]):
            if attn_mask is not None:
                if attn_mask.size()[0] != (len(self.layers)):
                    raise ValueError(
                        f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                        f" {head_mask.size()[0]}."
                    )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_attention_mask,
                    head_mask[idx] if head_mask is not None else None,
                    None,
                    output_attentions,
                    use_cache,
                    position_ids,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_attention_mask,
                    position_ids=position_ids,
                    layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


@add_start_docstrings(
    "The bare OPT Model outputting raw hidden-states without any specific head on top.",
    OPT_START_DOCSTRING,
)
class OPTModel(OPTPreTrainedModel):
    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.decoder = OPTDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.decoder.embed_tokens = value

    def get_decoder(self):
        return self.decoder

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_EXPECTED_OUTPUT_SHAPE,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, past_key_value, dec_hidden, dec_attn)
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs

        return BaseModelOutputWithPast(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
        )
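

# Hedged usage sketch (added in this edit, not part of the original file), using the documented checkpoint
# `facebook/opt-350m` referenced in the code-sample docstrings above:
#
#     >>> from transformers import AutoTokenizer, OPTModel
#     >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
#     >>> model = OPTModel.from_pretrained("facebook/opt-350m")
#     >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#     >>> outputs = model(**inputs)
#     >>> outputs.last_hidden_state.shape  # (batch_size, sequence_length, hidden dim of the final projection)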
  865. class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin):
  866. _tied_weights_keys = ["lm_head.weight"]
  867. def __init__(self, config):
  868. super().__init__(config)
  869. self.model = OPTModel(config)
  870. # the lm_head weight is automatically tied to the embed tokens weight
  871. self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
  872. # Initialize weights and apply final processing
  873. self.post_init()
  874. def get_input_embeddings(self):
  875. return self.model.decoder.embed_tokens
  876. def set_input_embeddings(self, value):
  877. self.model.decoder.embed_tokens = value
  878. def get_output_embeddings(self):
  879. return self.lm_head
  880. def set_output_embeddings(self, new_embeddings):
  881. self.lm_head = new_embeddings
  882. def set_decoder(self, decoder):
  883. self.model.decoder = decoder
  884. def get_decoder(self):
  885. return self.model.decoder
  886. @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
  887. def forward(
  888. self,
  889. input_ids: torch.LongTensor = None,
  890. attention_mask: Optional[torch.Tensor] = None,
  891. head_mask: Optional[torch.Tensor] = None,
  892. past_key_values: Optional[List[torch.FloatTensor]] = None,
  893. inputs_embeds: Optional[torch.FloatTensor] = None,
  894. labels: Optional[torch.LongTensor] = None,
  895. use_cache: Optional[bool] = None,
  896. output_attentions: Optional[bool] = None,
  897. output_hidden_states: Optional[bool] = None,
  898. return_dict: Optional[bool] = None,
  899. position_ids: Optional[torch.LongTensor] = None,
  900. ) -> Union[Tuple, CausalLMOutputWithPast]:
  901. r"""
  902. Args:
  903. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  904. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  905. provide it.
  906. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  907. [`PreTrainedTokenizer.__call__`] for details.
  908. [What are input IDs?](../glossary#input-ids)
  909. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  910. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  911. - 1 for tokens that are **not masked**,
  912. - 0 for tokens that are **masked**.
  913. [What are attention masks?](../glossary#attention-mask)
  914. head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
  915. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  916. - 1 indicates the head is **not masked**,
  917. - 0 indicates the head is **masked**.
  918. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  919. Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  920. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
  921. shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
  922. tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
  923. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  924. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  925. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  926. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  927. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  928. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  929. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  930. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  931. than the model's internal embedding lookup matrix.
  932. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  933. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  934. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  935. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  936. use_cache (`bool`, *optional*):
  937. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  938. (see `past_key_values`).
  939. output_attentions (`bool`, *optional*):
  940. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  941. returned tensors for more detail.
  942. output_hidden_states (`bool`, *optional*):
  943. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  944. for more detail.
  945. return_dict (`bool`, *optional*):
  946. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  947. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  948. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  949. config.n_positions - 1]`. for padding use -1.
  950. [What are position IDs?](../glossary#position-ids)
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OPTForCausalLM

        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
        ```"""
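        # Fall back to the config defaults for any output-control flag that was not passed explicitly.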
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
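        # Project the decoder's final hidden states onto the vocabulary to obtain per-token logits;
        # `.contiguous()` ensures the resulting tensor is laid out contiguously in memory.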
        logits = self.lm_head(outputs[0]).contiguous()

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
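
        # When `return_dict=False`, pack everything into a plain tuple; the loss is prepended only if labels were given.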
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
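        # Called by `generate()` during beam search: reorder each layer's cached key/value tensors along the
        # batch dimension so that the cache follows the beams selected at this step.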
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


@add_start_docstrings(
    """
    The OPT Model transformer with a sequence classification head on top (linear layer).

    [`OPTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it needs to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    OPT_START_DOCSTRING,
)
class OPTForSequenceClassification(OPTPreTrainedModel):
    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = OPTModel(config)
        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
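
        Example (an illustrative sketch only; "facebook/opt-350m" ships no fine-tuned classification head, so the
        `score` layer is randomly initialized here and the predicted label is not meaningful):

        ```python
        >>> import torch
        >>> from transformers import AutoTokenizer, OPTForSequenceClassification

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
        >>> model = OPTForSequenceClassification.from_pretrained("facebook/opt-350m")

        >>> inputs = tokenizer("Hey, are you conscious?", return_tensors="pt")
        >>> with torch.no_grad():
        ...     logits = model(**inputs).logits
        >>> predicted_class_id = logits.argmax(dim=-1).item()
        ```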
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = input_ids.shape[:2]
        else:
            batch_size, sequence_length = inputs_embeds.shape[:2]

        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
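                # `eq` + `argmax` locates the first padding token in each row; subtracting 1 gives the index of the
                # last non-padding token. If a row contains no padding, `argmax` returns 0, so the -1 wraps (via the
                # modulo below) to the final position of the sequence.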
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning_once(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds`."
                )
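
        # Pool by gathering, for each sequence in the batch, the logits at the last non-padding position
        # (or simply the last position when padding cannot be detected).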
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
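            # Infer `problem_type` from `num_labels` and the label dtype when it is not set in the config:
            # regression, single-label classification, or multi-label classification.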
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value


@add_start_docstrings(
    """
    The OPT Model transformer with a span classification head on top for extractive question-answering tasks like SQuAD
    (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    OPT_START_DOCSTRING,
)
class OPTForQuestionAnswering(OPTPreTrainedModel):
    def __init__(self, config: OPTConfig):
        super().__init__(config)
        self.model = OPTModel(config)
        self.qa_outputs = nn.Linear(config.word_embed_proj_dim, 2)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(OPT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=QuestionAnsweringModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
            are not taken into account for computing the loss.
        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, OPTForQuestionAnswering
        >>> import torch

        >>> torch.manual_seed(4)  # doctest: +IGNORE_RESULT
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

        >>> # note: we are loading an OPTForQuestionAnswering from the hub here,
        >>> # so the head will be randomly initialized, hence the predictions will be random
        >>> model = OPTForQuestionAnswering.from_pretrained("facebook/opt-350m")

        >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"

        >>> inputs = tokenizer(question, text, return_tensors="pt")
        >>> with torch.no_grad():
        ...     outputs = model(**inputs)

        >>> answer_start_index = outputs.start_logits.argmax()
        >>> answer_end_index = outputs.end_logits.argmax()

        >>> answer_offset = len(tokenizer(question)[0])

        >>> predict_answer_tokens = inputs.input_ids[
        ...     0, answer_offset + answer_start_index : answer_offset + answer_end_index + 1
        ... ]
        >>> predicted = tokenizer.decode(predict_answer_tokens)
        >>> predicted
        ' a nice puppet'
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
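
        # The QA head maps each hidden state to 2 values; splitting on the last dimension yields per-token
        # start and end logits of shape `(batch_size, sequence_length)`.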
        logits = self.qa_outputs(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
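        # Compute the span-extraction loss only when both start and end labels are provided.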
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, an extra dimension may have been added; squeeze it away
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs; we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index).to(logits.device)
            end_positions = end_positions.clamp(0, ignored_index).to(logits.device)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
        if not return_dict:
            output = (start_logits, end_logits) + transformer_outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.model.decoder.embed_tokens = value