| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811 |
- # coding=utf-8
- # Copyright 2022 Salesforce authors, The EleutherAI, and HuggingFace Teams. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch CodeGen model."""
- from typing import Optional, Tuple, Union
- import torch
- import torch.utils.checkpoint
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, StaticCache
- from ...generation import GenerationMixin
- from ...modeling_attn_mask_utils import AttentionMaskConverter
- from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
- from ...modeling_utils import PreTrainedModel
- from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
- from .configuration_codegen import CodeGenConfig
# Module-level logger obtained through the package's `logging` utilities.
logger = logging.get_logger(__name__)
# Checkpoint and config names used by the `add_code_sample_docstrings` decorators on the `forward` methods below.
_CHECKPOINT_FOR_DOC = "Salesforce/codegen-2B-mono"
_CONFIG_FOR_DOC = "CodeGenConfig"
# Mirrors transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    """Build the sin/cos lookup table used for rotary position embeddings.

    Inverse frequencies are ``f_j = 1 / 10000 ** (2j / dim)`` for ``j = 0 .. ceil(dim / 2) - 1``. Row ``p`` of the
    result is ``[sin(p * f_0), ..., sin(p * f_last), cos(p * f_0), ..., cos(p * f_last)]``.

    Args:
        num_pos: number of positions (rows) to generate.
        dim: rotary dimension; ``ceil(dim / 2)`` frequencies are used.

    Returns:
        Tensor of shape ``(num_pos, 2 * ceil(dim / 2))``, which is ``(num_pos, dim)`` for even ``dim``.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.int64) / dim
    inv_freq = 1.0 / (10000**exponents)
    positions = torch.arange(num_pos, dtype=torch.int64).float()
    # Outer product: angles[p, j] = positions[p] * inv_freq[j]
    angles = torch.einsum("i , j -> i j", positions, inv_freq).float()
    return torch.cat((torch.sin(angles), torch.cos(angles)), dim=1)
# Mirrors transformers.models.gptj.modeling_gptj.rotate_every_two
def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    """Rotate each adjacent (even, odd) pair along the 4th axis: ``(a, b) -> (-b, a)``.

    The input is indexed as ``x[:, :, :, ...]``, so a rank-4 tensor is expected.
    """
    evens = x[:, :, :, ::2]
    odds = x[:, :, :, 1::2]
    # Pair (-odd, even) on a new trailing axis, then fold that axis back into the last dimension.
    return torch.stack((-odds, evens), dim=-1).flatten(-2)
# Mirrors transformers.models.gptj.modeling_gptj.apply_rotary_pos_emb
def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to ``tensor``.

    ``sin`` and ``cos`` get a singleton axis inserted at position 2. Each entry is repeated twice along axis 3, so
    both members of an (even, odd) pair in ``tensor`` share one angle (the pairing matches ``rotate_every_two``).
    """
    sin_pairs = sin[:, :, None, :].repeat_interleave(2, dim=3)
    cos_pairs = cos[:, :, None, :].repeat_interleave(2, dim=3)
    return tensor * cos_pairs + rotate_every_two(tensor) * sin_pairs
class CodeGenAttention(nn.Module):
    """
    Multi-head self-attention with rotary position embeddings (RoPE) for CodeGen.

    Queries, values and keys come from one fused, bias-free projection (`qkv_proj`). Its output is split as `mp_num`
    shards, each laid out as (query, value, key); see `forward`.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()

        max_positions = config.max_position_embeddings
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        # Passed to `layer_past.update(...)` in `forward` to select this layer's cache slot.
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.embed_dim = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_attention_heads
        if self.head_dim * self.num_attention_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                f" `num_attention_heads`: {self.num_attention_heads})."
            )
        # sqrt(head_dim), computed in fp32 and cast to the default dtype; `_attn` divides the scores by it.
        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
        # Fused projection producing 3 * embed_dim features (query, value, key), without bias.
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        # Number of leading per-head features that receive RoPE; `None` selects the branch that rotates all of them.
        self.rotary_dim = config.rotary_dim
        pos_embd_dim = self.rotary_dim or self.embed_dim
        # NOTE(review): when `rotary_dim` is None the table is built with `embed_dim` features, but `forward` rotates
        # per-head tensors whose last axis is `head_dim`. These only broadcast with a single head. Confirm whether
        # that path is reachable in practice.
        # Plain tensor attribute (not a registered buffer); `forward` moves it to the device of `position_ids`.
        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)

    def _split_heads(self, x, n_head, dim_head, mp_num):
        # Split the last axis into (n_head // mp_num, dim_head), then merge the shard axis (x.shape[-2]) with
        # that per-shard head axis. For `forward`'s input this maps (..., mp_num, local_dim) -> (..., n_head, dim_head).
        reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
        reshaped = reshaped.reshape(x.shape[:-2] + (-1,) + reshaped.shape[-1:])
        return reshaped

    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
        """
        Merge the head axis and the head-size axis into a single feature axis of size
        `num_attention_heads * attn_head_size`.

        Rank-4 input is permuted with (0, 2, 1, 3) and rank-5 input with (0, 1, 3, 2, 4) before the last two axes
        are flattened. Any other rank raises `ValueError`.
        """
        if len(tensor.shape) == 5:
            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
        elif len(tensor.shape) == 4:
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(
        self,
        query,
        key,
        value,
        attention_mask=None,
        head_mask=None,
    ):
        """
        Scaled dot-product attention on head-split tensors.

        Scores are computed in fp32. The additive mask is applied *before* the division by `self.scale_attn`, so the
        mask values are scaled as well. Probabilities are cast back to `value.dtype`, then passed through dropout and
        the optional multiplicative `head_mask` before being applied to `value`.

        Returns:
            `(attn_output, attn_weights)`, where `attn_weights` are the post-dropout, post-head-mask probabilities.
        """
        # Keep the attention weights computation in fp32 to avoid overflow issues
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if attention_mask is not None:
            # The mask may be longer than the current key length; keep only the columns that match.
            causal_mask = attention_mask[:, :, :, : key.shape[-2]]
            attn_weights += causal_mask

        attn_weights = attn_weights / self.scale_attn
        attn_weights = nn.Softmax(dim=-1)(attn_weights)
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        layer_past: Optional[Cache] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[
        Tuple[torch.Tensor, Tuple[torch.Tensor]],
        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
    ]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: input activations whose last axis has size `embed_dim` (fed to `qkv_proj`).
            layer_past: optional `Cache`. If given, the rotated keys and the values are passed to
                `layer_past.update(...)` for layer `self.layer_idx`, and attention runs over what it returns.
            attention_mask: additive mask added to the raw scores (sliced to the key length in `_attn`).
            position_ids: row indices into the sinusoidal table. Required, since its `.device` is read.
            head_mask: optional multiplicative mask applied to the attention probabilities.
            use_cache: not read by this method.
            output_attentions: if True, the attention probabilities are appended to the outputs.
            cache_position: forwarded to `layer_past.update` through `cache_kwargs`.

        Returns:
            `(attn_output, layer_past)`, or `(attn_output, layer_past, attn_weights)` when `output_attentions` is set.
            `layer_past` is the same object that was passed in (possibly `None`).
        """
        qkv = self.qkv_proj(hidden_states)
        # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
        mp_num = 4
        # View the fused features as `mp_num` shards; each shard holds a query, value and key slice of `local_dim`.
        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))

        local_dim = self.head_dim * self.num_attention_heads // mp_num
        # The order inside each shard is (query, value, key), not (query, key, value).
        query, value, key = torch.split(qkv_split, local_dim, dim=-1)
        query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
        key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)

        value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
        # Swap axes 1 and 2 so heads come before the sequence axis (query/key are permuted the same way after RoPE).
        value = value.permute(0, 2, 1, 3)

        embed_positions = self.embed_positions
        if embed_positions.device != position_ids.device:
            # Move the table once and keep the moved copy for later calls.
            embed_positions = embed_positions.to(position_ids.device)
            self.embed_positions = embed_positions

        sincos = embed_positions[position_ids]
        # The table stores all sin columns first, then all cos columns.
        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

        if self.rotary_dim is not None:
            # Partial rotary: only the first `rotary_dim` features of each head are rotated; the rest pass through.
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, sin, cos)
            query = apply_rotary_pos_emb(query, sin, cos)

        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)

        # `key` is cast to `hidden_states.dtype` only when it is written to the cache below, i.e. after RoPE rather
        # than before it, because k_rot in the original codebase is always in fp32.
        # Reference: https://github.com/salesforce/CodeGen/blob/f210c3bb1216c975ad858cd4132c0fdeabf4bfc2/codegen1/jaxformer/hf/codegen/modeling_codegen.py#L38
        if layer_past is not None:
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "partial_rotation_size": self.rotary_dim,
                "cache_position": cache_position,
            }
            key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs)

        # compute self-attention: V x Softmax(QK^T)
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, layer_past)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
# Mirrors transformers.models.gptj.modeling_gptj.GPTJMLP with GPTJ->CodeGen
class CodeGenMLP(nn.Module):
    """Position-wise feed-forward network: Linear -> activation -> Linear -> dropout.

    Args:
        intermediate_size: width of the hidden layer (callers pass ``4 * n_embd`` unless ``n_inner`` is set).
        config: model config providing ``n_embd``, ``activation_function`` and ``resid_pdrop``.
    """

    def __init__(self, intermediate_size, config):  # in MLP: intermediate_size= 4 * embed_dim
        super().__init__()
        model_width = config.n_embd

        self.fc_in = nn.Linear(model_width, intermediate_size)
        self.fc_out = nn.Linear(intermediate_size, model_width)

        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.FloatTensor:
        # Apply the four stages in order: expand, activate, project back, dropout.
        for stage in (self.fc_in, self.act, self.fc_out, self.dropout):
            hidden_states = stage(hidden_states)
        return hidden_states
# Mirrors transformers.models.gptj.modeling_gptj.GPTJBlock with GPTJ->CodeGen (except `__init__`)
class CodeGenBlock(nn.Module):
    """One CodeGen decoder layer.

    A single LayerNorm (``ln_1``) feeds both the attention and the MLP. Their outputs are summed together with the
    un-normalized residual.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        mlp_width = 4 * config.n_embd if config.n_inner is None else config.n_inner
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CodeGenAttention(config, layer_idx)
        self.mlp = CodeGenMLP(mlp_width, config)

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        layer_past: Optional[Cache] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """Run the layer.

        Returns ``(hidden_states, present, (attentions))`` when ``use_cache`` is truthy. Otherwise the second element
        of the attention outputs (the cache) is dropped, giving ``(hidden_states, (attentions))``.
        """
        residual = hidden_states
        normed = self.ln_1(hidden_states)

        attn_outputs = self.attn(
            hidden_states=normed,
            layer_past=layer_past,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        attn_output, extras = attn_outputs[0], attn_outputs[1:]  # extras: present, (attentions)

        # Attention and MLP read the same normalized input; both are added to the residual.
        hidden_states = attn_output + self.mlp(normed) + residual

        kept = extras if use_cache else extras[1:]
        return (hidden_states,) + kept  # hidden_states, present, (attentions)
class CodeGenPreTrainedModel(PreTrainedModel):
    """
    Abstract base for CodeGen models. It binds the models to `CodeGenConfig` and defines how their weights are
    initialized, on top of the loading/downloading interface inherited from `PreTrainedModel`.
    """

    config_class = CodeGenConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["CodeGenBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialize ``module``'s parameters in place.

        ``nn.Linear`` and ``nn.Embedding`` weights are drawn from N(0, ``config.initializer_range``). Linear biases
        and the embedding's padding row are zeroed, and LayerNorm is reset to weight 1 and bias 0. Other module
        types are left untouched.
        """
        if isinstance(module, nn.Linear):
            # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
# Class-level documentation attached to the CodeGen model classes through `add_start_docstrings`.
CODEGEN_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`CodeGenConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Per-argument documentation for the `forward` methods. `{0}` is filled in with the input shape via `.format(...)`
# at each use site, so any literal braces added here must be doubled.
CODEGEN_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_attention_heads,)` or `(n_layer, num_attention_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_dim)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (keys and values in the self-attention blocks) that can be used to speed up
            sequential decoding. This typically consists in the `past_key_values` returned by the model at a previous
            stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance, see our
            [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layer`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""
@add_start_docstrings(
    "The bare CodeGen Model transformer outputting raw hidden-states without any specific head on top.",
    CODEGEN_START_DOCSTRING,
)
class CodeGenModel(CodeGenPreTrainedModel):
    # CodeGen decoder stack: token embedding `wte` -> dropout -> `config.n_layer` CodeGenBlocks -> final LayerNorm `ln_f`.
    # The class documentation is supplied by the `add_start_docstrings` decorator above.

    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.n_embd
        self.vocab_size = config.vocab_size
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)
        # Each block receives its index so its attention layer can address its own slot in the shared cache.
        self.h = nn.ModuleList([CodeGenBlock(config, layer_idx=i) for i in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        # NOTE(review): this attribute is not read anywhere in this class (each CodeGenAttention reads
        # `config.rotary_dim` itself), and `min()` raises TypeError if `config.rotary_dim` is None.
        self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    @add_start_docstrings_to_model_forward(CODEGEN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Unset flags fall back to the config defaults.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # XOR: raise when both or neither of `input_ids` / `inputs_embeds` are given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        # kept for BC (non `Cache` `past_key_values` inputs)
        # With caching on, a missing or tuple-style cache is converted to a DynamicCache, and the result is
        # converted back to the legacy format before returning.
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
            if past_key_values is None:
                past_key_values = DynamicCache()
            else:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
                )

        seq_length = inputs_embeds.shape[1]
        if cache_position is None:
            # Default positions continue right after the tokens already held in the cache.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)

        if position_ids is None:
            # Same positions for every batch row (leading batch dim of 1).
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x num_attention_heads x N x N
        # head_mask has shape n_layer x batch x num_attention_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
        hidden_states = inputs_embeds

        if token_type_ids is not None:
            # Token types are embedded with the same `wte` table as the tokens and added to the input embeddings.
            token_type_ids = token_type_ids.view(-1, seq_length)
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)
        output_shape = (-1, seq_length, hidden_states.size(-1))

        next_decoder_cache = None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, block in enumerate(self.h):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional arguments follow CodeGenBlock.forward's order. `layer_past` is None here; `use_cache`
                # was forced to False above for this mode.
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    None,
                    causal_mask,
                    position_ids,
                    head_mask[i],
                    use_cache,
                    output_attentions,
                    cache_position,
                )
            else:
                outputs = block(
                    hidden_states=hidden_states,
                    layer_past=past_key_values,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    cache_position=cache_position,
                )

            # Block outputs: (hidden_states, present, attentions) with caching, (hidden_states, attentions) without.
            hidden_states = outputs[0]
            if use_cache is True:
                next_decoder_cache = outputs[1]

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(
                v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        # Build the attention mask consumed by the blocks. Returns None when no explicit mask is needed.
        if self.config._attn_implementation == "flash_attention_2":
            # Flash attention takes the 2D padding mask as-is, and only when it actually contains padding (a 0).
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        sequence_length = input_tensor.shape[1]
        # Key length of the 4D mask: the static cache's maximum size, else the 2D mask's length, else
        # past + current + 1.
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            device=device,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        device: torch.device,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Masked entries hold `torch.finfo(dtype).min` and visible entries hold 0, so the mask is meant to be added to
        the attention scores.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            device (`torch.device`):
                The device to place the 4D attention mask on.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keep `min_dtype` only for key positions after each query's absolute position; zero out the rest.
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # A sum of exactly 0 means "causally visible" (0) and "padding" (0 in the 2D mask): mask it out too.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
@add_start_docstrings(
    """
    The CodeGen Model transformer with a language modeling head on top.
    """,
    CODEGEN_START_DOCSTRING,
)
class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = CodeGenModel(config)
        # Projects final hidden states (n_embd) onto vocabulary logits (vocab_size).
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    @add_start_docstrings_to_model_forward(CODEGEN_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor]]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Target token ids for language modeling. The shift by one position happens inside the model, so
            `labels = input_ids` is valid. Entries equal to `-100` are ignored by the loss; all other entries must lie
            in `[0, ..., config.vocab_size - 1]`.
        """
        return_dict = self.config.use_return_dict if return_dict is None else return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        hidden_states = transformer_outputs[0]

        # Logits are upcast to fp32 so fp16 sampling behaves and the loss is computed in fp32 (mesh-tf parity):
        # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179
        lm_logits = self.lm_head(hidden_states).to(torch.float32)

        loss = None
        if labels is not None:
            # Keep labels on the logits' device so this also works when the model is split across devices.
            labels = labels.to(lm_logits.device)
            # Position t predicts token t + 1: drop the last logit and the first label, then flatten.
            flat_logits = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            flat_targets = labels[..., 1:].contiguous().view(-1)
            loss = CrossEntropyLoss()(flat_logits, flat_targets).to(hidden_states.dtype)

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=lm_logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )

        output = (lm_logits,) + transformer_outputs[1:]
        return output if loss is None else (loss,) + output

    @staticmethod
    def _reorder_cache(
        past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> Tuple[Tuple[torch.Tensor]]:
        """
        Reorder a legacy (tuple-of-tuples) `past_key_values` cache along the batch axis (dim 0) according to
        `beam_idx`. This keeps the cache aligned with the selected beams during [`~PretrainedModel.beam_search`] or
        [`~PretrainedModel.beam_sample`].
        """

        def _pick_beams(state):
            return state.index_select(0, beam_idx.to(state.device))

        return tuple(tuple(_pick_beams(state) for state in layer_cache) for layer_cache in past_key_values)
|