- # coding=utf-8
- # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import inspect
- import os
- from typing import Optional, Tuple
- import torch
- import torch.nn.functional as F
- from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
- if is_flash_attn_2_available():
- from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
- from flash_attn import flash_attn_func, flash_attn_varlen_func
- _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
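- # Note: `window_size` was only added to `flash_attn_func` in later flash-attn 2.x releases, so we
- # probe the signature instead of hard-coding a version check; sliding-window attention is only
- # requested below when the installed flash-attn actually supports it.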
- def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
- """
- Retrieves indexing data required to repad unpadded (ragged) tensors.
- Arguments:
- attention_mask (`torch.Tensor`):
- Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
- Return:
- indices (`torch.Tensor`):
- The indices of non-masked tokens from the flattened input sequence.
- cu_seqlens (`torch.Tensor`):
- The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
- max_seqlen_in_batch (`int`):
- Maximum sequence length in batch.
- """
- seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
- indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
- max_seqlen_in_batch = seqlens_in_batch.max().item()
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
- return (
- indices,
- cu_seqlens,
- max_seqlen_in_batch,
- )
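- # Illustrative example (not executed): for an attention mask of
- #   [[1, 1, 0],
- #    [1, 1, 1]]
- # `_get_unpad_data` returns indices = [0, 1, 3, 4, 5] (flattened positions of the valid tokens),
- # cu_seqlens = [0, 2, 5], and max_seqlen_in_batch = 3.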
- def _upad_input(
- query_layer: torch.Tensor,
- key_layer: torch.Tensor,
- value_layer: torch.Tensor,
- attention_mask: torch.Tensor,
- query_length: int,
- ):
- """
- Unpads the query, key, and value tensors, packing all tokens along a single dimension even though they belong to different sequences in the batch.
- This function is used instead of `flash_attn.bert_padding.unpad_input` to avoid recomputing the same intermediate
- tensors for the query, key, and value tensors.
- Arguments:
- query_layer (`torch.Tensor`):
- Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
- key_layer (`torch.Tensor`):
- Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
- value_layer (`torch.Tensor`):
- Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
- attention_mask (`torch.Tensor`):
- Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
- query_length (`int`):
- Target length, i.e. the length of the query sequence.
- Return:
- query_layer (`torch.Tensor`):
- Query state without padding. Shape: (total_target_length, num_heads, head_dim).
- key_layer (`torch.Tensor`):
- Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
- value_layer (`torch.Tensor`):
- Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
- indices_q (`torch.Tensor`):
- The indices of non-masked tokens from the flattened input target sequence.
- (cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.Tensor]`):
- The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
- Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
- """
- indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
- batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
- key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
- value_layer = index_first_axis(
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
- )
- if query_length == kv_seq_len:
- query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
- cu_seqlens_q = cu_seqlens_k
- max_seqlen_in_batch_q = max_seqlen_in_batch_k
- indices_q = indices_k
- elif query_length == 1:
- max_seqlen_in_batch_q = 1
- cu_seqlens_q = torch.arange(
- batch_size + 1, dtype=torch.int32, device=query_layer.device
- ) # There is a memcpy here, that is very bad.
- indices_q = cu_seqlens_q[:-1]
- query_layer = query_layer.squeeze(1)
- else:
- # The -q_len: slice assumes left padding.
- attention_mask = attention_mask[:, -query_length:]
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
- return (
- query_layer,
- key_layer,
- value_layer,
- indices_q,
- (cu_seqlens_q, cu_seqlens_k),
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
- )
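- # Shape sketch (illustrative only): with batch_size=2, kv_seq_len=3 and the mask from the example
- # above, the key/value tensors go from (2, 3, num_key_value_heads, head_dim) to
- # (5, num_key_value_heads, head_dim), keeping only the 5 valid tokens. During decoding
- # (query_length == 1) the query is simply squeezed to (batch_size, num_heads, head_dim),
- # one entry per batch element.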
- def prepare_fa2_from_position_ids(query, key, value, position_ids):
- """
- This function returns the arguments necessary to call `flash_attn_varlen_func` on packed sequences.
- The query, key, and value states are all flattened across the batch dimension.
- Cumulative sequence lengths of each example in the batch are extracted from `position_ids`.
- NOTE: ideally, cumulative lengths should be prepared at the data collator stage.
- Arguments:
- query (`torch.Tensor`):
- Query state. Shape: (batch_size, query_length, num_heads, head_dim).
- key (`torch.Tensor`):
- Key state. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
- value (`torch.Tensor`):
- Value state. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
- position_ids (`torch.Tensor`):
- Int tensor of shape (batch_size, sequence_length) giving the position of each token within its sequence; a reset to 0 marks the start of a new packed sequence.
- Return:
- query (`torch.Tensor`):
- Query state without padding. Shape: (total_target_length, num_heads, head_dim).
- key (`torch.Tensor`):
- Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
- value (`torch.Tensor`):
- Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
- indices_q (`torch.Tensor`):
- The indices of non-masked tokens from the flattened input target sequence.
- (cu_seqlens_q, cu_seqlens_k) (`Tuple[torch.Tensor]`):
- The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
- (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
- Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
- """
- query = query.view(-1, query.size(-2), query.size(-1))
- key = key.view(-1, key.size(-2), key.size(-1))
- value = value.view(-1, value.size(-2), value.size(-1))
- position_ids = position_ids.flatten()
- indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
- cu_seq_lens = torch.cat(
- (
- indices_q[position_ids == 0],
- torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
- )
- )
- max_length = position_ids.max() + 1
- return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
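- # Worked example (illustrative only): for a packed batch with position_ids = [[0, 1, 2, 0, 1]],
- # the flattened positions reset to 0 at index 3, so cu_seq_lens = [0, 3, 5] and max_length = 3,
- # i.e. two sequences of lengths 3 and 2 packed into a single row.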
- def _flash_attention_forward(
- query_states: torch.Tensor,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- attention_mask: torch.Tensor,
- query_length: int,
- is_causal: bool,
- dropout: float = 0.0,
- position_ids: Optional[torch.Tensor] = None,
- softmax_scale: Optional[float] = None,
- sliding_window: Optional[int] = None,
- use_top_left_mask: bool = False,
- softcap: Optional[float] = None,
- deterministic: Optional[bool] = None,
- ):
- """
- Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
- the input is first unpadded, attention is then computed on the unpadded tensors, and the final attention output is re-padded.
- Args:
- query_states (`torch.Tensor`):
- Input query states to be passed to Flash Attention API
- key_states (`torch.Tensor`):
- Input key states to be passed to Flash Attention API
- value_states (`torch.Tensor`):
- Input value states to be passed to Flash Attention API
- attention_mask (`torch.Tensor`):
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
- position of padding tokens and 1 for the position of non-padding tokens.
- query_length (`int`):
- The length of the query sequence in tokens.
- is_causal (`bool`):
- Whether to apply a causal attention mask.
- dropout (`float`):
- Attention dropout probability.
- position_ids (`torch.Tensor`, *optional*):
- The position indices of the tokens, used to detect packed sequences when no attention mask is provided.
- softmax_scale (`float`, *optional*):
- The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
- sliding_window (`int`, *optional*):
- The size of the sliding attention window, if the model uses sliding-window attention.
- use_top_left_mask (`bool`, defaults to `False`):
- flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which became the default in flash_attn>=2.1. This attribute is used to handle the difference.
- softcap (`float`, *optional*):
- Softcap for the attention logits, used e.g. in gemma2.
- deterministic (`bool`, *optional*):
- Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
- """
- if not use_top_left_mask:
- causal = is_causal
- else:
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
- causal = is_causal and query_length != 1
- # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
- use_sliding_windows = (
- _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
- )
- flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
- if is_flash_attn_greater_or_equal("2.4.1"):
- if deterministic is None:
- deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
- flash_kwargs["deterministic"] = deterministic
- if softcap is not None:
- flash_kwargs["softcap"] = softcap
- # Contains at least one padding token in the sequence
- if attention_mask is not None:
- batch_size = query_states.shape[0]
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
- query_states, key_states, value_states, attention_mask, query_length
- )
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
- attn_output_unpad = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- **flash_kwargs,
- )
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
- # If position_ids is provided and the batch contains packed sequences (i.e. the position ids are not
- # monotonically non-decreasing, which happens when they reset to 0 at a sequence boundary), use
- # `flash_attn_varlen_func` to prevent cross-example attention while enabling a padding-free approach.
- # We also check that we are in the pre-fill/training stage (query_length != 1).
- # Note: the `torch.diff(...)` condition is evaluated last to short-circuit and avoid the CUDA
- # synchronization it incurs during inference (where query_length == 1 always).
- elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
- batch_size = query_states.size(0)
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
- query_states, key_states, value_states, position_ids
- )
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
- attn_output = flash_attn_varlen_func(
- query_states,
- key_states,
- value_states,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_k=cu_seqlens_k,
- max_seqlen_q=max_seqlen_in_batch_q,
- max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
- softmax_scale=softmax_scale,
- causal=causal,
- **flash_kwargs,
- )
- attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
- else:
- attn_output = flash_attn_func(
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
- )
- return attn_output
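- # Example call from a model's attention module (a hedged sketch; `config.sliding_window` is a
- # hypothetical attribute and all shapes are illustrative):
- #
- #   attn_output = _flash_attention_forward(
- #       query_states,    # (batch_size, query_length, num_heads, head_dim)
- #       key_states,      # (batch_size, kv_seq_len, num_key_value_heads, head_dim)
- #       value_states,    # (batch_size, kv_seq_len, num_key_value_heads, head_dim)
- #       attention_mask,  # (batch_size, kv_seq_len) padding mask, or None
- #       query_length=query_states.shape[1],
- #       is_causal=True,
- #       dropout=0.0,
- #       sliding_window=getattr(config, "sliding_window", None),
- #   )
- #   # attn_output: (batch_size, query_length, num_heads, head_dim)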