modeling_phi3.py 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525
  1. # coding=utf-8
  2. # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Phi-3 model."""
  16. import math
  17. import warnings
  18. from typing import List, Optional, Tuple, Union
  19. import torch
  20. import torch.utils.checkpoint
  21. from torch import nn
  22. from torch.nn import CrossEntropyLoss
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
  25. from ...generation import GenerationMixin
  26. from ...modeling_attn_mask_utils import AttentionMaskConverter
  27. from ...modeling_outputs import (
  28. BaseModelOutputWithPast,
  29. CausalLMOutputWithPast,
  30. SequenceClassifierOutputWithPast,
  31. TokenClassifierOutput,
  32. )
  33. from ...modeling_utils import PreTrainedModel
  34. from ...utils import (
  35. add_code_sample_docstrings,
  36. add_start_docstrings,
  37. add_start_docstrings_to_model_forward,
  38. is_flash_attn_2_available,
  39. is_flash_attn_greater_or_equal_2_10,
  40. logging,
  41. replace_return_docstrings,
  42. )
  43. from .configuration_phi3 import Phi3Config
  44. if is_flash_attn_2_available():
  45. from ...modeling_flash_attention_utils import _flash_attention_forward
# Module-level logger obtained from the `logging` helper imported from `...utils`.
logger = logging.get_logger(__name__)
# Checkpoint / config identifiers for the documentation helpers imported above (e.g. `add_code_sample_docstrings`);
# presumably consumed by docstring decorators further down this file — TODO confirm.
_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
_CONFIG_FOR_DOC = "Phi3Config"
  49. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
  50. class Phi3RMSNorm(nn.Module):
  51. def __init__(self, hidden_size, eps=1e-6):
  52. """
  53. Phi3RMSNorm is equivalent to T5LayerNorm
  54. """
  55. super().__init__()
  56. self.weight = nn.Parameter(torch.ones(hidden_size))
  57. self.variance_epsilon = eps
  58. def forward(self, hidden_states):
  59. input_dtype = hidden_states.dtype
  60. hidden_states = hidden_states.to(torch.float32)
  61. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  62. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  63. return self.weight * hidden_states.to(input_dtype)
  64. def extra_repr(self):
  65. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  66. # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
  67. class Phi3RotaryEmbedding(nn.Module):
  68. def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
  69. super().__init__()
  70. self.dim = dim
  71. self.max_position_embeddings = max_position_embeddings
  72. self.base = base
  73. inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
  74. self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
  75. @torch.no_grad()
  76. def forward(self, x, position_ids, seq_len=None):
  77. # x: [bs, num_attention_heads, seq_len, head_size]
  78. self.inv_freq.to(x.device)
  79. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  80. position_ids_expanded = position_ids[:, None, :].float()
  81. # Force float32 since bfloat16 loses precision on long contexts
  82. # See https://github.com/huggingface/transformers/pull/29285
  83. device_type = x.device.type
  84. device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
  85. with torch.autocast(device_type=device_type, enabled=False):
  86. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  87. emb = torch.cat((freqs, freqs), dim=-1)
  88. cos = emb.cos()
  89. sin = emb.sin()
  90. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  91. class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
  92. def __init__(self, dim, config, device=None):
  93. warnings.warn(
  94. "The class Phi3SuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. Please"
  95. " use Phi3LongRoPEScaledRotaryEmbedding instead.",
  96. FutureWarning,
  97. )
  98. super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
  99. self.short_factor = config.rope_scaling["short_factor"]
  100. self.long_factor = config.rope_scaling["long_factor"]
  101. self.original_max_position_embeddings = config.original_max_position_embeddings
  102. @torch.no_grad()
  103. def forward(self, x, position_ids, seq_len=None):
  104. seq_len = torch.max(position_ids) + 1
  105. if seq_len > self.original_max_position_embeddings:
  106. ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
  107. else:
  108. ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
  109. inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
  110. self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
  111. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  112. position_ids_expanded = position_ids[:, None, :].float()
  113. # Force float32 since bfloat16 loses precision on long contexts
  114. # See https://github.com/huggingface/transformers/pull/29285
  115. device_type = x.device.type
  116. device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
  117. with torch.autocast(device_type=device_type, enabled=False):
  118. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  119. emb = torch.cat((freqs, freqs), dim=-1)
  120. scale = self.max_position_embeddings / self.original_max_position_embeddings
  121. if scale <= 1.0:
  122. scaling_factor = 1.0
  123. else:
  124. scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
  125. cos = emb.cos() * scaling_factor
  126. sin = emb.sin() * scaling_factor
  127. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  128. class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
  129. def __init__(self, dim, config, device=None):
  130. warnings.warn(
  131. "The class Phi3YarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers",
  132. FutureWarning,
  133. )
  134. super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
  135. self.short_factor = config.rope_scaling["short_factor"]
  136. self.long_factor = config.rope_scaling["long_factor"]
  137. self.original_max_position_embeddings = config.original_max_position_embeddings
  138. @torch.no_grad()
  139. def forward(self, x, position_ids, seq_len=None):
  140. seq_len = torch.max(position_ids) + 1
  141. if seq_len > self.original_max_position_embeddings:
  142. ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
  143. else:
  144. ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
  145. inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
  146. self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
  147. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  148. position_ids_expanded = position_ids[:, None, :].float()
  149. # Force float32 since bfloat16 loses precision on long contexts
  150. # See https://github.com/huggingface/transformers/pull/29285
  151. device_type = x.device.type
  152. device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
  153. with torch.autocast(device_type=device_type, enabled=False):
  154. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  155. emb = torch.cat((freqs, freqs), dim=-1)
  156. scale = self.max_position_embeddings / self.original_max_position_embeddings
  157. if scale <= 1.0:
  158. scaling_factor = 1.0
  159. else:
  160. scaling_factor = 0.1 * math.log(scale) + 1.0
  161. cos = emb.cos() * scaling_factor
  162. sin = emb.sin() * scaling_factor
  163. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  164. class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding):
  165. def __init__(self, dim, config, device=None):
  166. super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
  167. self.short_factor = config.rope_scaling["short_factor"]
  168. self.long_factor = config.rope_scaling["long_factor"]
  169. self.original_max_position_embeddings = config.original_max_position_embeddings
  170. @torch.no_grad()
  171. def forward(self, x, position_ids, seq_len=None):
  172. seq_len = seq_len or torch.max(position_ids) + 1
  173. if seq_len > self.original_max_position_embeddings:
  174. ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
  175. else:
  176. ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
  177. inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
  178. self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
  179. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  180. position_ids_expanded = position_ids[:, None, :].float()
  181. # Force float32 since bfloat16 loses precision on long contexts
  182. # See https://github.com/huggingface/transformers/pull/29285
  183. device_type = x.device.type
  184. device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
  185. with torch.autocast(device_type=device_type, enabled=False):
  186. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  187. emb = torch.cat((freqs, freqs), dim=-1)
  188. scale = self.max_position_embeddings / self.original_max_position_embeddings
  189. if scale <= 1.0:
  190. scaling_factor = 1.0
  191. else:
  192. scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
  193. cos = emb.cos() * scaling_factor
  194. sin = emb.sin() * scaling_factor
  195. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  196. # Copied from transformers.models.llama.modeling_llama.rotate_half
  197. def rotate_half(x):
  198. """Rotates half the hidden dims of the input."""
  199. x1 = x[..., : x.shape[-1] // 2]
  200. x2 = x[..., x.shape[-1] // 2 :]
  201. return torch.cat((-x2, x1), dim=-1)
  202. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  203. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  204. """Applies Rotary Position Embedding to the query and key tensors.
  205. Args:
  206. q (`torch.Tensor`): The query tensor.
  207. k (`torch.Tensor`): The key tensor.
  208. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  209. sin (`torch.Tensor`): The sine part of the rotary embedding.
  210. position_ids (`torch.Tensor`, *optional*):
  211. Deprecated and unused.
  212. unsqueeze_dim (`int`, *optional*, defaults to 1):
  213. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  214. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  215. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  216. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  217. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  218. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  219. Returns:
  220. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  221. """
  222. cos = cos.unsqueeze(unsqueeze_dim)
  223. sin = sin.unsqueeze(unsqueeze_dim)
  224. q_embed = (q * cos) + (rotate_half(q) * sin)
  225. k_embed = (k * cos) + (rotate_half(k) * sin)
  226. return q_embed, k_embed
  227. class Phi3MLP(nn.Module):
  228. def __init__(self, config):
  229. super().__init__()
  230. self.config = config
  231. self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
  232. self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  233. self.activation_fn = ACT2FN[config.hidden_act]
  234. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  235. up_states = self.gate_up_proj(hidden_states)
  236. gate, up_states = up_states.chunk(2, dim=-1)
  237. up_states = up_states * self.activation_fn(gate)
  238. return self.down_proj(up_states)
  239. # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
  240. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  241. """
  242. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  243. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  244. """
  245. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  246. if n_rep == 1:
  247. return hidden_states
  248. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  249. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  250. class Phi3Attention(nn.Module):
  251. """Multi-headed attention from 'Attention Is All You Need' paper"""
  252. def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
  253. super().__init__()
  254. self.config = config
  255. self.layer_idx = layer_idx
  256. if layer_idx is None:
  257. logger.warning_once(
  258. f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
  259. "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
  260. "when creating this class."
  261. )
  262. self.attention_dropout = config.attention_dropout
  263. self.hidden_size = config.hidden_size
  264. self.num_heads = config.num_attention_heads
  265. self.head_dim = self.hidden_size // self.num_heads
  266. self.num_key_value_heads = config.num_key_value_heads
  267. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  268. self.max_position_embeddings = config.max_position_embeddings
  269. self.original_max_position_embeddings = config.original_max_position_embeddings
  270. self.rope_theta = config.rope_theta
  271. self.rope_scaling = config.rope_scaling
  272. self.is_causal = True
  273. if (self.head_dim * self.num_heads) != self.hidden_size:
  274. raise ValueError(
  275. f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
  276. f" and `num_heads`: {self.num_heads})."
  277. )
  278. op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
  279. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  280. self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
  281. self._init_rope()
  282. def _init_rope(self):
  283. if self.rope_scaling is None:
  284. self.rotary_emb = Phi3RotaryEmbedding(
  285. self.head_dim,
  286. max_position_embeddings=self.max_position_embeddings,
  287. base=self.rope_theta,
  288. )
  289. else:
  290. scaling_type = self.config.rope_scaling["type"]
  291. if scaling_type == "longrope":
  292. self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.head_dim, self.config)
  293. else:
  294. raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
  295. def forward(
  296. self,
  297. hidden_states: torch.Tensor,
  298. attention_mask: Optional[torch.Tensor] = None,
  299. position_ids: Optional[torch.LongTensor] = None,
  300. past_key_value: Optional[Cache] = None,
  301. output_attentions: bool = False,
  302. use_cache: bool = False,
  303. cache_position: Optional[torch.LongTensor] = None,
  304. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  305. logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.")
  306. bsz, q_len, _ = hidden_states.size()
  307. qkv = self.qkv_proj(hidden_states)
  308. query_pos = self.num_heads * self.head_dim
  309. query_states = qkv[..., :query_pos]
  310. key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
  311. value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
  312. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  313. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  314. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  315. kv_seq_len = key_states.shape[-2]
  316. if past_key_value is not None:
  317. if self.layer_idx is None:
  318. raise ValueError(
  319. f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
  320. "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
  321. "with a layer index."
  322. )
  323. kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
  324. cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
  325. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  326. if past_key_value is not None:
  327. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
  328. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  329. # repeat k/v heads if n_kv_heads < n_heads
  330. key_states = repeat_kv(key_states, self.num_key_value_groups)
  331. value_states = repeat_kv(value_states, self.num_key_value_groups)
  332. attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
  333. if attention_mask is not None:
  334. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  335. attn_weights += causal_mask
  336. # upcast attention to fp32
  337. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
  338. attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
  339. attn_output = torch.matmul(attn_weights, value_states)
  340. if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
  341. raise ValueError(
  342. f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
  343. f" {attn_output.size()}"
  344. )
  345. attn_output = attn_output.transpose(1, 2).contiguous()
  346. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
  347. attn_output = self.o_proj(attn_output)
  348. if not output_attentions:
  349. attn_weights = None
  350. return attn_output, attn_weights, past_key_value
  351. class Phi3FlashAttention2(Phi3Attention):
  352. """
  353. Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stays
  354. untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
  355. flash attention and deal with padding tokens in case the input contains any of them.
  356. """
  357. # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
  358. def __init__(self, *args, **kwargs):
  359. super().__init__(*args, **kwargs)
  360. # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
  361. # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
  362. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
  363. self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
  364. def forward(
  365. self,
  366. hidden_states: torch.Tensor,
  367. attention_mask: Optional[torch.LongTensor] = None,
  368. position_ids: Optional[torch.LongTensor] = None,
  369. past_key_value: Optional[Cache] = None,
  370. output_attentions: bool = False,
  371. use_cache: bool = False,
  372. cache_position: Optional[torch.LongTensor] = None,
  373. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  374. # Phi3FlashAttention2 attention does not support output_attentions
  375. output_attentions = False
  376. bsz, q_len, _ = hidden_states.size()
  377. qkv = self.qkv_proj(hidden_states)
  378. query_pos = self.num_heads * self.head_dim
  379. query_states = qkv[..., :query_pos]
  380. key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
  381. value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
  382. # Flash attention requires the input to have the shape
  383. # batch_size x seq_length x head_dim x hidden_dim
  384. # therefore we just need to keep the original shape
  385. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  386. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  387. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  388. kv_seq_len = key_states.shape[-2]
  389. if past_key_value is not None:
  390. if self.layer_idx is None:
  391. raise ValueError(
  392. f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
  393. "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
  394. "with a layer index."
  395. )
  396. kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
  397. # Because the input can be padded, the absolute sequence length depends on the max position id.
  398. rotary_seq_len = (
  399. max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
  400. )
  401. cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids)
  402. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  403. if past_key_value is not None:
  404. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
  405. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  406. # repeat k/v heads if n_kv_heads < n_heads
  407. key_states = repeat_kv(key_states, self.num_key_value_groups)
  408. value_states = repeat_kv(value_states, self.num_key_value_groups)
  409. attn_dropout = self.attention_dropout if self.training else 0.0
  410. # In PEFT, usually we cast the layer norms in float32 for training stability reasons
  411. # therefore the input hidden states gets silently casted in float32. Hence, we need
  412. # cast them back in the correct dtype just to be sure everything works as expected.
  413. # This might slowdown training & inference so it is recommended to not cast the LayerNorms
  414. # in fp32.
  415. if query_states.dtype == torch.float32:
  416. if torch.is_autocast_enabled():
  417. target_dtype = torch.get_autocast_gpu_dtype()
  418. # Handle the case where the model is quantized
  419. elif hasattr(self.config, "_pre_quantization_dtype"):
  420. target_dtype = self.config._pre_quantization_dtype
  421. else:
  422. target_dtype = self.qkv_proj.weight.dtype
  423. logger.warning_once(
  424. f"The input hidden states seems to be silently casted in float32, this might be related to"
  425. f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
  426. f" {target_dtype}."
  427. )
  428. query_states = query_states.to(target_dtype)
  429. key_states = key_states.to(target_dtype)
  430. value_states = value_states.to(target_dtype)
  431. # Reashape to the expected shape for Flash Attention
  432. query_states = query_states.transpose(1, 2)
  433. key_states = key_states.transpose(1, 2)
  434. value_states = value_states.transpose(1, 2)
  435. attn_output = _flash_attention_forward(
  436. query_states,
  437. key_states,
  438. value_states,
  439. attention_mask,
  440. q_len,
  441. position_ids=position_ids,
  442. dropout=attn_dropout,
  443. sliding_window=getattr(self.config, "sliding_window", None),
  444. use_top_left_mask=self._flash_attn_uses_top_left_mask,
  445. is_causal=self.is_causal,
  446. )
  447. attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
  448. attn_output = self.o_proj(attn_output)
  449. if not output_attentions:
  450. attn_weights = None
  451. return attn_output, attn_weights, past_key_value
  452. # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
  453. # TODO @Arthur no longer copied from LLama after static cache
  454. class Phi3SdpaAttention(Phi3Attention):
  455. """
  456. Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
  457. `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
  458. SDPA API.
  459. """
  460. # Adapted from Phi3Attention.forward
  461. def forward(
  462. self,
  463. hidden_states: torch.Tensor,
  464. attention_mask: Optional[torch.Tensor] = None,
  465. position_ids: Optional[torch.LongTensor] = None,
  466. past_key_value: Optional[Cache] = None,
  467. output_attentions: bool = False,
  468. use_cache: bool = False,
  469. cache_position: Optional[torch.LongTensor] = None,
  470. ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
  471. if output_attentions:
  472. # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
  473. logger.warning_once(
  474. "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
  475. 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
  476. )
  477. return super().forward(
  478. hidden_states=hidden_states,
  479. attention_mask=attention_mask,
  480. position_ids=position_ids,
  481. past_key_value=past_key_value,
  482. output_attentions=output_attentions,
  483. use_cache=use_cache,
  484. )
  485. bsz, q_len, _ = hidden_states.size()
  486. qkv = self.qkv_proj(hidden_states)
  487. query_pos = self.num_heads * self.head_dim
  488. query_states = qkv[..., :query_pos]
  489. key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
  490. value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
  491. query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  492. key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  493. value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
  494. kv_seq_len = key_states.shape[-2]
  495. if past_key_value is not None:
  496. kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
  497. cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
  498. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  499. if past_key_value is not None:
  500. cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
  501. key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  502. key_states = repeat_kv(key_states, self.num_key_value_groups)
  503. value_states = repeat_kv(value_states, self.num_key_value_groups)
  504. causal_mask = attention_mask
  505. if attention_mask is not None:
  506. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  507. # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
  508. # Reference: https://github.com/pytorch/pytorch/issues/112577.
  509. if query_states.device.type == "cuda" and attention_mask is not None:
  510. query_states = query_states.contiguous()
  511. key_states = key_states.contiguous()
  512. value_states = value_states.contiguous()
  513. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  514. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  515. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
  516. is_causal = True if causal_mask is None and q_len > 1 else False
  517. attn_output = torch.nn.functional.scaled_dot_product_attention(
  518. query_states,
  519. key_states,
  520. value_states,
  521. attn_mask=causal_mask,
  522. dropout_p=self.attention_dropout if self.training else 0.0,
  523. is_causal=is_causal,
  524. )
  525. attn_output = attn_output.transpose(1, 2).contiguous()
  526. attn_output = attn_output.view(bsz, q_len, self.hidden_size)
  527. attn_output = self.o_proj(attn_output)
  528. return attn_output, None, past_key_value
  529. PHI3_ATTENTION_CLASSES = {
  530. "eager": Phi3Attention,
  531. "flash_attention_2": Phi3FlashAttention2,
  532. "sdpa": Phi3SdpaAttention,
  533. }
  534. class Phi3DecoderLayer(nn.Module):
  535. def __init__(self, config: Phi3Config, layer_idx: int):
  536. super().__init__()
  537. self.config = config
  538. self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
  539. self.mlp = Phi3MLP(config)
  540. self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  541. self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
  542. self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
  543. self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  544. def forward(
  545. self,
  546. hidden_states: torch.Tensor,
  547. attention_mask: Optional[torch.Tensor] = None,
  548. position_ids: Optional[torch.LongTensor] = None,
  549. past_key_value: Optional[Tuple[torch.Tensor]] = None,
  550. output_attentions: Optional[bool] = False,
  551. use_cache: Optional[bool] = False,
  552. cache_position: Optional[torch.LongTensor] = None,
  553. **kwargs,
  554. ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
  555. """
  556. Args:
  557. hidden_states (`torch.FloatTensor`):
  558. input to the layer of shape `(batch, seq_len, embed_dim)`
  559. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
  560. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  561. position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
  562. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
  563. `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
  564. output_attentions (`bool`, *optional*):
  565. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  566. returned tensors for more detail.
  567. use_cache (`bool`, *optional*):
  568. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
  569. (see `past_key_values`).
  570. past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
  571. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  572. Indices depicting the position of the input sequence tokens in the sequence
  573. kwargs (`dict`, *optional*):
  574. Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
  575. into the model
  576. """
  577. residual = hidden_states
  578. hidden_states = self.input_layernorm(hidden_states)
  579. # Self Attention
  580. attn_outputs, self_attn_weights, present_key_value = self.self_attn(
  581. hidden_states=hidden_states,
  582. attention_mask=attention_mask,
  583. position_ids=position_ids,
  584. past_key_value=past_key_value,
  585. output_attentions=output_attentions,
  586. use_cache=use_cache,
  587. cache_position=cache_position,
  588. )
  589. hidden_states = residual + self.resid_attn_dropout(attn_outputs)
  590. residual = hidden_states
  591. hidden_states = self.post_attention_layernorm(hidden_states)
  592. hidden_states = self.mlp(hidden_states)
  593. hidden_states = residual + self.resid_mlp_dropout(hidden_states)
  594. outputs = (hidden_states,)
  595. if output_attentions:
  596. outputs += (self_attn_weights,)
  597. if use_cache:
  598. outputs += (present_key_value,)
  599. return outputs
  600. PHI3_START_DOCSTRING = r"""
  601. This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
  602. library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
  603. etc.)
  604. This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
  605. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
  606. and behavior.
  607. Parameters:
  608. config ([`Phi3Config`]):
  609. Model configuration class with all the parameters of the model. Initializing with a config file does not
  610. load the weights associated with the model, only the configuration. Check out the
  611. [`~PreTrainedModel.from_pretrained`] method to load the model weights.
  612. """
  613. @add_start_docstrings(
  614. "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
  615. PHI3_START_DOCSTRING,
  616. )
  617. class Phi3PreTrainedModel(PreTrainedModel):
  618. config_class = Phi3Config
  619. base_model_prefix = "model"
  620. supports_gradient_checkpointing = True
  621. _no_split_modules = ["Phi3DecoderLayer"]
  622. _skip_keys_device_placement = "past_key_values"
  623. _supports_flash_attn_2 = True
  624. _supports_sdpa = True
  625. _supports_cache_class = True
  626. _version = "0.0.5"
  627. def _init_weights(self, module):
  628. std = self.config.initializer_range
  629. if isinstance(module, nn.Linear):
  630. module.weight.data.normal_(mean=0.0, std=std)
  631. if module.bias is not None:
  632. module.bias.data.zero_()
  633. elif isinstance(module, nn.Embedding):
  634. module.weight.data.normal_(mean=0.0, std=std)
  635. if module.padding_idx is not None:
  636. module.weight.data[module.padding_idx].zero_()
  637. PHI3_INPUTS_DOCSTRING = r"""
  638. Args:
  639. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  640. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  641. it.
  642. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  643. [`PreTrainedTokenizer.__call__`] for details.
  644. [What are input IDs?](../glossary#input-ids)
  645. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  646. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  647. - 1 for tokens that are **not masked**,
  648. - 0 for tokens that are **masked**.
  649. [What are attention masks?](../glossary#attention-mask)
  650. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  651. [`PreTrainedTokenizer.__call__`] for details.
  652. If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
  653. `past_key_values`).
  654. If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
  655. and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
  656. information on the default strategy.
  657. - 1 indicates the head is **not masked**,
  658. - 0 indicates the head is **masked**.
  659. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  660. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  661. config.n_positions - 1]`.
  662. [What are position IDs?](../glossary#position-ids)
  663. past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
  664. Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  665. blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  666. returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
  667. Two formats are allowed:
  668. - a [`~cache_utils.Cache`] instance, see our
  669. [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
  670. - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
  671. shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
  672. cache format.
  673. The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
  674. legacy cache format will be returned.
  675. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
  676. have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
  677. of shape `(batch_size, sequence_length)`.
  678. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  679. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  680. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  681. model's internal embedding lookup matrix.
  682. use_cache (`bool`, *optional*):
  683. If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  684. `past_key_values`).
  685. output_attentions (`bool`, *optional*):
  686. Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  687. tensors for more detail.
  688. output_hidden_states (`bool`, *optional*):
  689. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  690. more detail.
  691. return_dict (`bool`, *optional*):
  692. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  693. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  694. Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
  695. this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
  696. the complete sequence length.
  697. """
@add_start_docstrings(
    "The bare Phi-3 model outputting raw hidden-states without any specific head on top.",
    PHI3_START_DOCSTRING,
)
class Phi3Model(Phi3PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]

    Args:
        config: Phi3Config
    """

    def __init__(self, config: Phi3Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # NOTE(review): `embed_dropout` is registered here, but `forward` below never applies it to the embeddings.
        # Confirm whether that is intentional before relying on `config.embd_pdrop` having any effect.
        self.embed_dropout = nn.Dropout(config.embd_pdrop)
        # One decoder layer per hidden layer. `layer_idx` is what each attention module uses to address its own
        # entry in the shared KV cache.
        self.layers = nn.ModuleList(
            [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module (`embed_tokens`)."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module (`embed_tokens`) with `value`."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Argument documentation is attached by `add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)`,
        # so this method uses inline comments only.

        # Per-call flags fall back to the model config when left as None.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # XOR: raises when both or neither of `input_ids` / `inputs_embeds` are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # kept for BC (non `Cache` `past_key_values` inputs)
        # When caching, a missing or legacy-tuple cache is wrapped in a DynamicCache here and converted back to the
        # legacy tuple format before returning (see `return_legacy_cache` further down).
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Default cache positions continue right after the tokens the cache already holds.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        # Default position ids reuse cache_position with a leading batch dimension of 1.
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            # Hidden states are recorded *before* each layer, so entry 0 is the (un-dropped-out) embedding output.
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional arguments follow Phi3DecoderLayer.forward's parameter order:
                # (hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache,
                #  cache_position).
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            # layer_outputs is (hidden_states, [attn_weights if output_attentions], [cache if use_cache]).
            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            # None entries are dropped, so the tuple's length depends on which outputs were requested.
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Optional[Cache],
        output_attentions: bool,
    ):
        """
        Build the attention mask handed to every decoder layer.

        Args:
            attention_mask (`torch.Tensor`, *optional*):
                User mask, either 2D `(batch_size, key_value_length)` or an already-prepared 4D mask.
            input_tensor (`torch.Tensor`):
                The input embeddings. Only their dtype, device, batch size (dim 0) and sequence length (dim 1) are
                read here; they are also forwarded to `AttentionMaskConverter._ignore_causal_mask_sdpa`.
            cache_position (`torch.Tensor`):
                Absolute positions of the current tokens.
            past_key_values (`Cache`, *optional*):
                The cache in use. Static and sliding-window caches fix the mask length to the cache's maximum shape.
            output_attentions (`bool`):
                When `True`, the SDPA shortcuts are skipped, because SDPA layers then fall back to the eager path.

        Returns:
            One of the following:

            - For flash-attention: the user mask unchanged when it contains padding (a `0`), otherwise `None`.
            - For SDPA: `None` when `_ignore_causal_mask_sdpa` says the kernel's own causal masking suffices.
            - Otherwise: the 4D mask from `_prepare_4d_causal_attention_mask_with_cache_position`, where `0` means
              attend and the dtype minimum means masked.
        """
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention applies causality itself; a mask is only needed when there is padding.
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)
        using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not (using_static_cache or using_sliding_window_cache)
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                sliding_window=self.config.sliding_window,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        # SlidingWindowCache or StaticCache
        if using_sliding_window_cache or using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        # DynamicCache or no cache
        else:
            # Key length: taken from the user mask when given, otherwise past tokens + current tokens + 1.
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
            config=self.config,
            past_key_values=past_key_values,
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    # The method below carries a `Copied from` marker, so its body is kept identical to the Mistral source it names;
    # edit it there and propagate rather than diverging here.
    @staticmethod
    # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phi3
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        config: Phi3Config,
        past_key_values: Cache,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to plcae the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`torch.Tensor`):
                Batch size.
            config (`Phi3Config`):
                The model's configuration class
            past_key_values (`Cache`):
                The cache class that is being used currently to generate
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            if config.sliding_window is not None:
                # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
                # the check is needed to verify is current checkpoint was trained with sliding window or not
                if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
                    sliding_attend_mask = torch.arange(target_length, device=device) <= (
                        cache_position.reshape(-1, 1) - config.sliding_window
                    )
                    diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
            causal_mask *= diagonal_attend_mask
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                if attention_mask.shape[-1] > target_length:
                    attention_mask = attention_mask[:, :target_length]
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        return causal_mask
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
    """
    Phi-3 causal language model.

    Wraps the Phi-3 decoder (`Phi3Model`, stored as `self.model`) and adds a bias-free linear head (`lm_head`) that
    maps hidden states to `config.vocab_size` logits.
    """

    _tied_weights_keys = ["lm_head.weight"]

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
    def __init__(self, config):
        super().__init__(config)
        self.model = Phi3Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
    def get_input_embeddings(self):
        return self.model.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
    def get_output_embeddings(self):
        return self.lm_head

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
    def set_decoder(self, decoder):
        self.model = decoder

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
    def get_decoder(self):
        return self.model

    # Ignore copy
    @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **loss_kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, Phi3ForCausalLM

        >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")

        >>> prompt = "This is an example script ."
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
        ```"""
        # With rope scaling enabled, warn exactly when decoding reaches `original_max_position_embeddings`.
        # `prepare_inputs_for_generation` below drops the cache at that point; callers that drive `forward`
        # directly do not get that treatment.
        if (
            use_cache
            and self.config.rope_scaling
            and cache_position is not None
            and cache_position[0] == self.config.original_max_position_embeddings
        ):
            logger.warning(
                f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed."
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        # NOTE(review): `cache_position` is not forwarded here, so `Phi3Model.forward` re-derives it from the cache
        # length. Forwarding it looks tempting, but after the cache reset in `prepare_inputs_for_generation` the
        # generic input preparation presumably still supplies a `cache_position` for only the newest token(s)
        # alongside the full `input_ids`. Confirm that interaction before changing this call.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # (`-0:` selects every position, which is why `num_logits_to_keep=0` keeps all logits.)
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        num_logits_to_keep=None,
        **kwargs,
    ):
        """
        Prepare the model inputs for one generation step.

        This override does one thing before deferring to `super().prepare_inputs_for_generation`:

        - It applies only when `rope_scaling` is set.
        - It fires once the sequence length first exceeds `config.original_max_position_embeddings`, while the cache
          still starts at or below that limit.
        - It then drops `past_key_values`, so the KV cache is rebuilt from scratch.

        All other arguments are passed through unchanged.
        """
        # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the
        # process

        # When the first time input length reached long and short factor switching point, enforce re-compute cache
        # It will cause downside of slower at this single token position, however, better than current failure.
        # NOTE(review): assumes `cache_position` is provided whenever `past_key_values` is set (as `generate` does);
        # indexing it would fail on `None`.
        if (
            past_key_values
            and self.config.rope_scaling
            and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1
        ):
            past_length = cache_position[0]
            if past_length <= self.config.original_max_position_embeddings:
                past_key_values = None

        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            num_logits_to_keep=num_logits_to_keep,
            **kwargs,
        )
        return model_inputs
  1110. @add_start_docstrings(
  1111. """
  1112. The [`Phi3Model`] with a sequence classification head on top (linear layer).
  1113. [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
  1114. (e.g. GPT-2) do.
  1115. Since it does classification on the last token, it requires to know the position of the last token. If a
  1116. `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
  1117. no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
  1118. padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
  1119. each row of the batch).
  1120. """,
  1121. PHI3_START_DOCSTRING,
  1122. )
  1123. # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
  1124. class Phi3ForSequenceClassification(Phi3PreTrainedModel):
  1125. def __init__(self, config):
  1126. super().__init__(config)
  1127. self.num_labels = config.num_labels
  1128. self.model = Phi3Model(config)
  1129. self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
  1130. # Initialize weights and apply final processing
  1131. self.post_init()
  1132. def get_input_embeddings(self):
  1133. return self.model.embed_tokens
  1134. def set_input_embeddings(self, value):
  1135. self.model.embed_tokens = value
  1136. @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
  1137. def forward(
  1138. self,
  1139. input_ids: Optional[torch.LongTensor] = None,
  1140. attention_mask: Optional[torch.Tensor] = None,
  1141. position_ids: Optional[torch.LongTensor] = None,
  1142. past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
  1143. inputs_embeds: Optional[torch.FloatTensor] = None,
  1144. labels: Optional[torch.LongTensor] = None,
  1145. use_cache: Optional[bool] = None,
  1146. output_attentions: Optional[bool] = None,
  1147. output_hidden_states: Optional[bool] = None,
  1148. return_dict: Optional[bool] = None,
  1149. ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
  1150. r"""
  1151. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1152. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1153. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1154. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1155. """
  1156. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1157. model_outputs = self.model(
  1158. input_ids,
  1159. attention_mask=attention_mask,
  1160. position_ids=position_ids,
  1161. past_key_values=past_key_values,
  1162. inputs_embeds=inputs_embeds,
  1163. use_cache=use_cache,
  1164. output_attentions=output_attentions,
  1165. output_hidden_states=output_hidden_states,
  1166. return_dict=return_dict,
  1167. )
  1168. hidden_states = model_outputs[0]
  1169. logits = self.score(hidden_states)
  1170. if input_ids is not None:
  1171. batch_size = input_ids.shape[0]
  1172. else:
  1173. batch_size = inputs_embeds.shape[0]
  1174. if self.config.pad_token_id is None and batch_size != 1:
  1175. raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
  1176. if self.config.pad_token_id is None:
  1177. sequence_lengths = -1
  1178. else:
  1179. if input_ids is not None:
  1180. # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
  1181. sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
  1182. sequence_lengths = sequence_lengths % input_ids.shape[-1]
  1183. sequence_lengths = sequence_lengths.to(logits.device)
  1184. else:
  1185. sequence_lengths = -1
  1186. pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
  1187. loss = None
  1188. if labels is not None:
  1189. loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
  1190. if not return_dict:
  1191. output = (pooled_logits,) + model_outputs[1:]
  1192. return ((loss,) + output) if loss is not None else output
  1193. return SequenceClassifierOutputWithPast(
  1194. loss=loss,
  1195. logits=pooled_logits,
  1196. past_key_values=model_outputs.past_key_values,
  1197. hidden_states=model_outputs.hidden_states,
  1198. attentions=model_outputs.attentions,
  1199. )
  1200. @add_start_docstrings(
  1201. """
  1202. [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
  1203. Named-Entity-Recognition (NER) tasks.
  1204. """,
  1205. PHI3_START_DOCSTRING,
  1206. )
  1207. # Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
  1208. class Phi3ForTokenClassification(Phi3PreTrainedModel):
  1209. def __init__(self, config: Phi3Config):
  1210. super().__init__(config)
  1211. self.num_labels = config.num_labels
  1212. self.model = Phi3Model(config)
  1213. if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
  1214. classifier_dropout = config.classifier_dropout
  1215. elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
  1216. classifier_dropout = config.hidden_dropout
  1217. else:
  1218. classifier_dropout = 0.1
  1219. self.dropout = nn.Dropout(classifier_dropout)
  1220. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1221. # Initialize weights and apply final processing
  1222. self.post_init()
  1223. @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
  1224. @add_code_sample_docstrings(
  1225. checkpoint=_CHECKPOINT_FOR_DOC,
  1226. output_type=TokenClassifierOutput,
  1227. config_class=_CONFIG_FOR_DOC,
  1228. )
  1229. def forward(
  1230. self,
  1231. input_ids: Optional[torch.LongTensor] = None,
  1232. past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
  1233. attention_mask: Optional[torch.Tensor] = None,
  1234. inputs_embeds: Optional[torch.Tensor] = None,
  1235. labels: Optional[torch.Tensor] = None,
  1236. use_cache: Optional[bool] = None,
  1237. output_attentions: Optional[bool] = None,
  1238. output_hidden_states: Optional[bool] = None,
  1239. return_dict: Optional[bool] = None,
  1240. **deprecated_arguments,
  1241. ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
  1242. r"""
  1243. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1244. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1245. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1246. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1247. """
  1248. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1249. model_outputs = self.model(
  1250. input_ids,
  1251. past_key_values=past_key_values,
  1252. attention_mask=attention_mask,
  1253. inputs_embeds=inputs_embeds,
  1254. use_cache=use_cache,
  1255. output_attentions=output_attentions,
  1256. output_hidden_states=output_hidden_states,
  1257. return_dict=return_dict,
  1258. )
  1259. hidden_states = model_outputs[0]
  1260. hidden_states = self.dropout(hidden_states)
  1261. logits = self.classifier(hidden_states)
  1262. loss = None
  1263. if labels is not None:
  1264. # move labels to correct device to enable model parallelism
  1265. labels = labels.to(logits.device)
  1266. batch_size, seq_length = labels.shape
  1267. loss_fct = CrossEntropyLoss()
  1268. loss = loss_fct(
  1269. logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
  1270. )
  1271. if not return_dict:
  1272. output = (logits,) + model_outputs[2:]
  1273. return ((loss,) + output) if loss is not None else output
  1274. return TokenClassifierOutput(
  1275. loss=loss,
  1276. logits=logits,
  1277. hidden_states=model_outputs.hidden_states,
  1278. attentions=model_outputs.attentions,
  1279. )