| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184 |
- # coding=utf-8
- # Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
- #
- # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
- # and OPT implementations in this library. It has been modified from its
- # original forms to accommodate minor architectural differences compared
- # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch Persimmon model."""
- import math
- from typing import List, Optional, Tuple, Union
- import torch
- import torch.utils.checkpoint
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, StaticCache
- from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import AttentionMaskConverter
- from ...modeling_outputs import (
- BaseModelOutputWithPast,
- CausalLMOutputWithPast,
- SequenceClassifierOutputWithPast,
- TokenClassifierOutput,
- )
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
- from ...modeling_utils import PreTrainedModel
- from ...utils import (
- add_code_sample_docstrings,
- add_start_docstrings,
- add_start_docstrings_to_model_forward,
- logging,
- replace_return_docstrings,
- )
- from .configuration_persimmon import PersimmonConfig
- logger = logging.get_logger(__name__)
- _CHECKPOINT_FOR_DOC = "adept/persimmon-8b-base"
- _CONFIG_FOR_DOC = "PersimmonConfig"
- # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
- class PersimmonRotaryEmbedding(nn.Module):
- def __init__(
- self,
- dim=None,
- max_position_embeddings=2048,
- base=10000,
- device=None,
- scaling_factor=1.0,
- rope_type="default",
- config: Optional[PersimmonConfig] = None,
- ):
- super().__init__()
- # TODO (joao): remove the `if` below, only used for BC
- self.rope_kwargs = {}
- if config is None:
- logger.warning_once(
- "`PersimmonRotaryEmbedding` can now be fully parameterized by passing the model config through the "
- "`config` argument. All other arguments will be removed in v4.46"
- )
- self.rope_kwargs = {
- "rope_type": rope_type,
- "factor": scaling_factor,
- "dim": dim,
- "base": base,
- "max_position_embeddings": max_position_embeddings,
- }
- self.rope_type = rope_type
- self.max_seq_len_cached = max_position_embeddings
- self.original_max_seq_len = max_position_embeddings
- else:
- # BC: "rope_type" was originally "type"
- if config.rope_scaling is not None:
- self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
- else:
- self.rope_type = "default"
- self.max_seq_len_cached = config.max_position_embeddings
- self.original_max_seq_len = config.max_position_embeddings
- self.config = config
- self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
- inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
- self.register_buffer("inv_freq", inv_freq, persistent=False)
- self.original_inv_freq = self.inv_freq
- def _dynamic_frequency_update(self, position_ids, device):
- """
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
- 1 - growing beyond the cached sequence length (allow scaling)
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
- """
- seq_len = torch.max(position_ids) + 1
- if seq_len > self.max_seq_len_cached: # growth
- inv_freq, self.attention_scaling = self.rope_init_fn(
- self.config, device, seq_len=seq_len, **self.rope_kwargs
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
- self.max_seq_len_cached = seq_len
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
- self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
- self.max_seq_len_cached = self.original_max_seq_len
- @torch.no_grad()
- def forward(self, x, position_ids):
- if "dynamic" in self.rope_type:
- self._dynamic_frequency_update(position_ids, device=x.device)
- # Core RoPE block
- inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
- position_ids_expanded = position_ids[:, None, :].float()
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
- device_type = x.device.type
- device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False):
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
- emb = torch.cat((freqs, freqs), dim=-1)
- cos = emb.cos()
- sin = emb.sin()
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
- cos = cos * self.attention_scaling
- sin = sin * self.attention_scaling
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
- # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Persimmon
- class PersimmonLinearScalingRotaryEmbedding(PersimmonRotaryEmbedding):
- """PersimmonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
- def __init__(self, *args, **kwargs):
- logger.warning_once(
- "`PersimmonLinearScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
- "`PersimmonRotaryEmbedding`, which now also does linear scaling (simply pass the model config to __init__)."
- )
- kwargs["rope_type"] = "linear"
- super().__init__(*args, **kwargs)
- # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Persimmon
- class PersimmonDynamicNTKScalingRotaryEmbedding(PersimmonRotaryEmbedding):
- """PersimmonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
- def __init__(self, *args, **kwargs):
- logger.warning_once(
- "`PersimmonDynamicNTKScalingRotaryEmbedding` is deprecated an will be removed in v4.46. Please use "
- "`PersimmonRotaryEmbedding`, which now also does dynamic ntk scaling (simply pass the model config to "
- "__init__)."
- )
- kwargs["rope_type"] = "dynamic"
- super().__init__(*args, **kwargs)
- # Copied from transformers.models.llama.modeling_llama.rotate_half
- def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
- # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
- """Applies Rotary Position Embedding to the query and key tensors.
- Args:
- q (`torch.Tensor`): The query tensor.
- k (`torch.Tensor`): The key tensor.
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
- sin (`torch.Tensor`): The sine part of the rotary embedding.
- position_ids (`torch.Tensor`, *optional*):
- Deprecated and unused.
- unsqueeze_dim (`int`, *optional*, defaults to 1):
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
- Returns:
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
- """
- cos = cos.unsqueeze(unsqueeze_dim)
- sin = sin.unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
- # Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXMLP with GPTNeoX->Persimmon
- class PersimmonMLP(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size)
- self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size)
- self.act = ACT2FN[config.hidden_act]
- def forward(self, hidden_states):
- hidden_states = self.dense_h_to_4h(hidden_states)
- hidden_states = self.act(hidden_states)
- hidden_states = self.dense_4h_to_h(hidden_states)
- return hidden_states
- class PersimmonAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config: PersimmonConfig, layer_idx: Optional[int] = None):
- super().__init__()
- self.config = config
- self.layer_idx = layer_idx
- if layer_idx is None:
- logger.warning_once(
- f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
- "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
- "when creating this class."
- )
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.rope_theta = config.rope_theta
- self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor)
- self.is_causal = True
- if (self.head_dim * self.num_heads) != self.hidden_size:
- raise ValueError(
- f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
- f" and `num_heads`: {self.num_heads})."
- )
- self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
- self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
- self.qk_layernorm = config.qk_layernorm
- if self.qk_layernorm:
- self.q_layernorm = nn.LayerNorm(
- config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
- )
- self.k_layernorm = nn.LayerNorm(
- config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
- )
- self.attention_dropout = nn.Dropout(config.attention_dropout)
- self.rotary_emb = PersimmonRotaryEmbedding(config=self.config)
- def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """
- Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
- storage as `fused_qkv`
- Args:
- fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]
- Returns:
- query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
- value: [batch_size, seq_length, num_heads, head_dim]
- """
- batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
- fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
- return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: bool = False,
- use_cache: bool = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- bsz, q_len, _ = hidden_states.size()
- # [batch_size, seq_length, 3 x hidden_size]
- fused_qkv = self.query_key_value(hidden_states)
- # 3 x [batch_size, seq_length, num_heads, head_dim]
- (query_states, key_states, value_states) = self._split_heads(fused_qkv)
- if self.qk_layernorm:
- query_states = self.q_layernorm(query_states)
- key_states = self.k_layernorm(key_states)
- # [batch_size, num_heads, seq_length, head_dim] -> [batch_size, seq_length, num_heads, head_dim]
- query_states = query_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- if position_embeddings is None:
- logger.warning_once(
- "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
- "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
- "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
- "removed and `position_embeddings` will be mandatory."
- )
- cos, sin = self.rotary_emb(value_states, position_ids)
- else:
- cos, sin = position_embeddings
- # Partial rotary embedding
- query_rot, query_pass = (
- query_states[..., : self.rotary_ndims],
- query_states[..., self.rotary_ndims :],
- )
- key_rot, key_pass = (
- key_states[..., : self.rotary_ndims],
- key_states[..., self.rotary_ndims :],
- )
- # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
- query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
- # [batch_size, seq_length, num_heads, head_dim]
- query_states = torch.cat((query_rot, query_pass), dim=-1)
- key_states = torch.cat((key_rot, key_pass), dim=-1)
- if past_key_value is not None:
- # Specific to RoPE models with partial rotation
- cache_kwargs = {
- "sin": sin,
- "cos": cos,
- "partial_rotation_size": self.rotary_ndims,
- "cache_position": cache_position,
- }
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
- if attention_mask is not None: # no matter the length, we just slice it
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights + causal_mask
- # upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query_states.dtype)
- attn_weights = self.attention_dropout(attn_weights)
- attn_output = torch.matmul(attn_weights, value_states)
- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.dense(attn_output)
- if not output_attentions:
- attn_weights = None
- return attn_output, attn_weights, past_key_value
- class PersimmonDecoderLayer(nn.Module):
- def __init__(self, config: PersimmonConfig, layer_idx: int):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.self_attn = PersimmonAttention(config=config, layer_idx=layer_idx)
- self.mlp = PersimmonMLP(config)
- self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.dropout = nn.Dropout(config.hidden_dropout)
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
- `[0, config.n_positions - 1]`.
- [What are position IDs?](../glossary#position-ids)
- past_key_value (`Tuple(torch.FloatTensor)`, *optional*):
- cached past key and value projection states
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
- (see `past_key_values`).
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence
- position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
- Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
- with `head_dim` being the embedding dimension of each attention head.
- """
- residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
- # Self Attention
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- )
- hidden_states = residual + hidden_states
- # Fully Connected
- residual = hidden_states
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = self.dropout(hidden_states)
- hidden_states = hidden_states + residual
- outputs = (hidden_states,)
- if output_attentions:
- outputs += (self_attn_weights,)
- if use_cache:
- outputs += (present_key_value,)
- return outputs
- PERSIMMON_START_DOCSTRING = r"""
- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
- etc.)
- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
- and behavior.
- Parameters:
- config ([`PersimmonConfig`]):
- Model configuration class with all the parameters of the model. Initializing with a config file does not
- load the weights associated with the model, only the configuration. Check out the
- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
- """
- @add_start_docstrings(
- "The bare Persimmon Model outputting raw hidden-states without any specific head on top.",
- PERSIMMON_START_DOCSTRING,
- )
- class PersimmonPreTrainedModel(PreTrainedModel):
- config_class = PersimmonConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["PersimmonDecoderLayer"]
- _skip_keys_device_placement = "past_key_values"
- _supports_cache_class = True
- _supports_quantized_cache = True
- _supports_static_cache = True
- def _init_weights(self, module):
- std = self.config.initializer_range
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
- PERSIMMON_INPUTS_DOCSTRING = r"""
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
- it.
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- [What are input IDs?](../glossary#input-ids)
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- - 1 for tokens that are **not masked**,
- - 0 for tokens that are **masked**.
- [What are attention masks?](../glossary#attention-mask)
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
- `past_key_values`).
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
- information on the default strategy.
- - 1 indicates the head is **not masked**,
- - 0 indicates the head is **masked**.
- position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
- config.n_positions - 1]`.
- [What are position IDs?](../glossary#position-ids)
- past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
- Two formats are allowed:
- - a [`~cache_utils.Cache`] instance, see our
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
- cache format.
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
- legacy cache format will be returned.
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
- of shape `(batch_size, sequence_length)`.
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
- model's internal embedding lookup matrix.
- use_cache (`bool`, *optional*):
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
- `past_key_values`).
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
- tensors for more detail.
- output_hidden_states (`bool`, *optional*):
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
- more detail.
- return_dict (`bool`, *optional*):
- Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
- the complete sequence length.
- """
- @add_start_docstrings(
- "The bare Persimmon Model outputting raw hidden-states without any specific head on top.",
- PERSIMMON_START_DOCSTRING,
- )
- class PersimmonModel(PersimmonPreTrainedModel):
- """
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PersimmonDecoderLayer`]
- Args:
- config: PersimmonConfig
- """
- def __init__(self, config: PersimmonConfig):
- super().__init__(config)
- self.padding_idx = config.pad_token_id
- self.vocab_size = config.vocab_size
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
- self.layers = nn.ModuleList(
- [PersimmonDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
- )
- self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.rotary_emb = PersimmonRotaryEmbedding(config=config)
- self.gradient_checkpointing = False
- # Initialize weights and apply final processing
- self.post_init()
- def get_input_embeddings(self):
- return self.embed_tokens
- def set_input_embeddings(self, value):
- self.embed_tokens = value
- @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- ) -> Union[Tuple, BaseModelOutputWithPast]:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
- # kept for BC (non `Cache` `past_key_values` inputs)
- return_legacy_cache = False
- if use_cache and not isinstance(past_key_values, Cache):
- return_legacy_cache = True
- if past_key_values is None:
- past_key_values = DynamicCache()
- else:
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
- logger.warning_once(
- "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
- "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
- "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
- )
- if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
- )
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
- causal_mask = self._update_causal_mask(
- attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
- )
- hidden_states = inputs_embeds
- # create position embeddings to be shared across the decoder layers
- position_embeddings = self.rotary_emb(hidden_states, position_ids)
- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
- next_decoder_cache = None
- for decoder_layer in self.layers:
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
- if self.gradient_checkpointing and self.training:
- layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
- hidden_states,
- causal_mask,
- position_ids,
- past_key_values,
- output_attentions,
- use_cache,
- cache_position,
- position_embeddings,
- )
- else:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=causal_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
- cache_position=cache_position,
- position_embeddings=position_embeddings,
- )
- hidden_states = layer_outputs[0]
- if use_cache:
- next_decoder_cache = layer_outputs[2 if output_attentions else 1]
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
- hidden_states = self.final_layernorm(hidden_states)
- # add hidden states from the last decoder layer
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
- next_cache = next_decoder_cache if use_cache else None
- if return_legacy_cache:
- next_cache = next_cache.to_legacy_cache()
- if not return_dict:
- return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
- return BaseModelOutputWithPast(
- last_hidden_state=hidden_states,
- past_key_values=next_cache,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
- )
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
- def _update_causal_mask(
- self,
- attention_mask: torch.Tensor,
- input_tensor: torch.Tensor,
- cache_position: torch.Tensor,
- past_key_values: Cache,
- output_attentions: bool,
- ):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and 0.0 in attention_mask:
- return attention_mask
- return None
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
- # to infer the attention mask.
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- using_static_cache = isinstance(past_key_values, StaticCache)
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
- attention_mask,
- inputs_embeds=input_tensor,
- past_key_values_length=past_seen_tokens,
- is_training=self.training,
- ):
- return None
- dtype, device = input_tensor.dtype, input_tensor.device
- sequence_length = input_tensor.shape[1]
- if using_static_cache:
- target_length = past_key_values.get_max_cache_shape()
- else:
- target_length = (
- attention_mask.shape[-1]
- if isinstance(attention_mask, torch.Tensor)
- else past_seen_tokens + sequence_length + 1
- )
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=target_length,
- dtype=dtype,
- device=device,
- cache_position=cache_position,
- batch_size=input_tensor.shape[0],
- )
- if (
- self.config._attn_implementation == "sdpa"
- and attention_mask is not None
- and attention_mask.device.type == "cuda"
- and not output_attentions
- ):
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
- # Details: https://github.com/pytorch/pytorch/issues/110213
- min_dtype = torch.finfo(dtype).min
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
- return causal_mask
- @staticmethod
- # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_4d_causal_attention_mask_with_cache_position
- def _prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask: torch.Tensor,
- sequence_length: int,
- target_length: int,
- dtype: torch.dtype,
- device: torch.device,
- cache_position: torch.Tensor,
- batch_size: int,
- **kwargs,
- ):
- """
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
- Args:
- attention_mask (`torch.Tensor`):
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
- `(batch_size, 1, query_length, key_value_length)`.
- sequence_length (`int`):
- The sequence length being processed.
- target_length (`int`):
- The target length: when generating with static cache, the mask should be as long as the static cache,
- to account for the 0 padding, the part of the cache that is not filled yet.
- dtype (`torch.dtype`):
- The dtype to use for the 4D attention mask.
- device (`torch.device`):
- The device to plcae the 4D attention mask on.
- cache_position (`torch.Tensor`):
- Indices depicting the position of the input sequence tokens in the sequence.
- batch_size (`torch.Tensor`):
- Batch size.
- """
- if attention_mask is not None and attention_mask.dim() == 4:
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
- causal_mask = attention_mask
- else:
- min_dtype = torch.finfo(dtype).min
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
- )
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
- return causal_mask
- class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin):
- _tied_weights_keys = ["lm_head.weight"]
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->PERSIMMON,Llama->Persimmon
- def __init__(self, config):
- super().__init__(config)
- self.model = PersimmonModel(config)
- self.vocab_size = config.vocab_size
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
- # Initialize weights and apply final processing
- self.post_init()
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
- def get_input_embeddings(self):
- return self.model.embed_tokens
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
- def get_output_embeddings(self):
- return self.lm_head
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
- def set_output_embeddings(self, new_embeddings):
- self.lm_head = new_embeddings
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
- def set_decoder(self, decoder):
- self.model = decoder
- # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
- def get_decoder(self):
- return self.model
- @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
- def forward(
- self,
- input_ids: torch.LongTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- num_logits_to_keep: int = 0,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
- r"""
- Args:
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
- num_logits_to_keep (`int`, *optional*):
- Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
- Returns:
- Example:
- ```python
- >>> from transformers import AutoTokenizer, PersimmonForCausalLM
- >>> model = PersimmonForCausalLM.from_pretrained("adept/persimmon-8b-base")
- >>> tokenizer = AutoTokenizer.from_pretrained("adept/persimmon-8b-base")
- >>> prompt = "human: Hey, what should I eat for dinner?"
- >>> inputs = tokenizer(prompt, return_tensors="pt")
- >>> # Generate
- >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
- >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
- 'human: Hey, what should I eat for dinner?\n\ncat: 🐱\n\nhuman: 😐\n\n'
- ```"""
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- cache_position=cache_position,
- )
- hidden_states = outputs[0]
- # No upscaling to float was ever done for Persimmon
- logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
- shift_labels = shift_labels.view(-1)
- # Enable model parallelism
- shift_labels = shift_labels.to(shift_logits.device)
- loss = loss_fct(shift_logits, shift_labels)
- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
- return CausalLMOutputWithPast(
- loss=loss,
- logits=logits,
- past_key_values=outputs.past_key_values,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
- @add_start_docstrings(
- """
- The Persimmon transformer with a sequence classification head on top (linear layer).
- [`PersimmonForSequenceClassification`] uses the last token in order to do the classification, as other causal
- models (e.g. GPT-2) do.
- Since it does classification on the last token, it requires to know the position of the last token. If a
- `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
- no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
- padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
- each row of the batch).
- """,
- PERSIMMON_START_DOCSTRING,
- )
- # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PERSIMMON,Llama->Persimmon
- class PersimmonForSequenceClassification(PersimmonPreTrainedModel):
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.model = PersimmonModel(config)
- self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
- # Initialize weights and apply final processing
- self.post_init()
- def get_input_embeddings(self):
- return self.model.embed_tokens
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
- @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- transformer_outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- hidden_states = transformer_outputs[0]
- logits = self.score(hidden_states)
- if input_ids is not None:
- batch_size = input_ids.shape[0]
- else:
- batch_size = inputs_embeds.shape[0]
- if self.config.pad_token_id is None and batch_size != 1:
- raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
- if self.config.pad_token_id is None:
- sequence_lengths = -1
- else:
- if input_ids is not None:
- # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
- sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
- sequence_lengths = sequence_lengths % input_ids.shape[-1]
- sequence_lengths = sequence_lengths.to(logits.device)
- else:
- sequence_lengths = -1
- pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
- loss = None
- if labels is not None:
- loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
- if not return_dict:
- output = (pooled_logits,) + transformer_outputs[1:]
- return ((loss,) + output) if loss is not None else output
- return SequenceClassifierOutputWithPast(
- loss=loss,
- logits=pooled_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
- @add_start_docstrings(
- """
- The Persimmon Model transformer with a token classification head on top (a linear layer on top of the hidden-states
- output) e.g. for Named-Entity-Recognition (NER) tasks.
- """,
- PERSIMMON_START_DOCSTRING,
- )
- # Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Persimmon, LLAMA->PERSIMMON
- class PersimmonForTokenClassification(PersimmonPreTrainedModel):
- def __init__(self, config):
- super().__init__(config)
- self.num_labels = config.num_labels
- self.model = PersimmonModel(config)
- if getattr(config, "classifier_dropout", None) is not None:
- classifier_dropout = config.classifier_dropout
- elif getattr(config, "hidden_dropout", None) is not None:
- classifier_dropout = config.hidden_dropout
- else:
- classifier_dropout = 0.1
- self.dropout = nn.Dropout(classifier_dropout)
- self.score = nn.Linear(config.hidden_size, config.num_labels)
- # Initialize weights and apply final processing
- self.post_init()
- def get_input_embeddings(self):
- return self.model.embed_tokens
- def set_input_embeddings(self, value):
- self.model.embed_tokens = value
- @add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
- @add_code_sample_docstrings(
- checkpoint=_CHECKPOINT_FOR_DOC,
- output_type=TokenClassifierOutput,
- config_class=_CONFIG_FOR_DOC,
- )
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ) -> Union[Tuple, TokenClassifierOutput]:
- r"""
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
- """
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- sequence_output = outputs[0]
- sequence_output = self.dropout(sequence_output)
- logits = self.score(sequence_output)
- loss = None
- if labels is not None:
- loss = self.loss_function(logits, labels, self.config)
- if not return_dict:
- output = (logits,) + outputs[2:]
- return ((loss,) + output) if loss is not None else output
- return TokenClassifierOutput(
- loss=loss,
- logits=logits,
- hidden_states=outputs.hidden_states,
- attentions=outputs.attentions,
- )
|