# modeling_flash_attention_utils.py
# coding=utf-8
# Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal


if is_flash_attn_2_available():
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
    from flash_attn import flash_attn_func, flash_attn_varlen_func

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)


def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Retrieves indexing data required to repad unpadded (ragged) tensors.

    Arguments:
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.

    Return:
        indices (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input sequence.
        cu_seqlens (`torch.Tensor`):
            The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        max_seqlen_in_batch (`int`):
            Maximum sequence length in batch.
    """
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
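

# Illustrative sketch (not part of the original module): how `_get_unpad_data` behaves on a small
# right-padded mask. The tensor values below are assumptions chosen for demonstration only, and the
# helper name `_example_get_unpad_data` is hypothetical; it is never called at import time.
def _example_get_unpad_data():
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.int32)
    indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)
    # indices    -> tensor([0, 1, 2, 4, 5]): positions of the valid tokens in the flattened mask
    # cu_seqlens -> tensor([0, 3, 5], dtype=torch.int32): cumulative sequence lengths per example
    # max_seqlen -> 3
    return indices, cu_seqlens, max_seqlen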


def _upad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
):
    """
    Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.

    This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
    tensors for query, key, value tensors.

    Arguments:
        query_layer (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        attention_mask (`torch.Tensor`):
            Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
        query_length (`int`):
            Target length.

    Return:
        query_layer (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key_layer (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value_layer (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
    value_layer = index_first_axis(
        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
    )
    if query_length == kv_seq_len:
        query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
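

# Illustrative sketch (not part of the original module): shape flow through `_upad_input` during
# prefill, where `query_length == kv_seq_len`. The shapes and mask values are assumptions for
# demonstration only; this hypothetical helper requires the `flash_attn` package (for
# `index_first_axis`) and is never called at import time.
def _example_upad_input():
    batch_size, seq_len, num_heads, head_dim = 2, 4, 2, 8
    query = torch.randn(batch_size, seq_len, num_heads, head_dim)
    key = torch.randn(batch_size, seq_len, num_heads, head_dim)
    value = torch.randn(batch_size, seq_len, num_heads, head_dim)
    # Left padding: the first example has 3 valid tokens, the second has 2.
    attention_mask = torch.tensor([[0, 1, 1, 1], [0, 0, 1, 1]])
    query, key, value, indices_q, (cu_q, cu_k), (max_q, max_k) = _upad_input(
        query, key, value, attention_mask, seq_len
    )
    # query/key/value now have shape (5, 2, 8): the 5 valid tokens of the whole batch, concatenated
    # cu_q == cu_k -> tensor([0, 3, 5], dtype=torch.int32); max_q == max_k == 3
    return query, key, value, indices_q, (cu_q, cu_k), (max_q, max_k)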


def prepare_fa2_from_position_ids(query, key, value, position_ids):
    """
    This function returns the necessary arguments to call `flash_attn_varlen_func`.
    All three query, key, value states will be flattened.
    Cumulative lengths of each example in the batch will be extracted from position_ids.

    NOTE: ideally cumulative lengths should be prepared at the data collator stage

    Arguments:
        query (`torch.Tensor`):
            Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
        position_ids (`torch.Tensor`):
            Int tensor of shape (batch_size, sequence_length), containing the position of each token within its
            sequence (positions restart from 0 at each packed sequence boundary).

    Return:
        query (`torch.Tensor`):
            Query state without padding. Shape: (total_target_length, num_heads, head_dim).
        key (`torch.Tensor`):
            Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        value (`torch.Tensor`):
            Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
        indices_q (`torch.Tensor`):
            The indices of non-masked tokens from the flattened input target sequence.
        (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
            The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
            Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
    """
    query = query.view(-1, query.size(-2), query.size(-1))
    key = key.view(-1, key.size(-2), key.size(-1))
    value = value.view(-1, value.size(-2), value.size(-1))
    position_ids = position_ids.flatten()
    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)

    cu_seq_lens = torch.cat(
        (
            indices_q[position_ids == 0],
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    )

    max_length = position_ids.max() + 1

    return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
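

# Illustrative sketch (not part of the original module): how `prepare_fa2_from_position_ids`
# derives the cumulative sequence lengths from packed position ids. The packing below (three
# sequences of lengths 3, 2 and 1 in a single row) and the helper name are assumptions for
# demonstration only; the helper is never called at import time.
def _example_prepare_fa2_from_position_ids():
    batch_size, seq_len, num_heads, head_dim = 1, 6, 2, 4
    query = torch.randn(batch_size, seq_len, num_heads, head_dim)
    key = torch.randn(batch_size, seq_len, num_heads, head_dim)
    value = torch.randn(batch_size, seq_len, num_heads, head_dim)
    # Positions restart at 0 at every packed sequence boundary.
    position_ids = torch.tensor([[0, 1, 2, 0, 1, 0]])
    query, key, value, indices_q, (cu_q, cu_k), (max_q, max_k) = prepare_fa2_from_position_ids(
        query, key, value, position_ids
    )
    # query/key/value are flattened to shape (6, 2, 4)
    # cu_q == cu_k -> tensor([0, 3, 5, 6], dtype=torch.int32); max_q == max_k == 3
    return query, key, value, indices_q, (cu_q, cu_k), (max_q, max_k)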


def _flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
):
    """
    Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
    first unpads the input, then computes the attention scores and pads the final attention scores.

    Args:
        query_states (`torch.Tensor`):
            Input query states to be passed to Flash Attention API
        key_states (`torch.Tensor`):
            Input key states to be passed to Flash Attention API
        value_states (`torch.Tensor`):
            Input value states to be passed to Flash Attention API
        attention_mask (`torch.Tensor`):
            The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
            position of padding tokens and 1 for the position of non-padding tokens.
        dropout (`float`):
            Attention dropout
        softmax_scale (`float`, *optional*):
            The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        use_top_left_mask (`bool`, defaults to `False`):
            flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is a bottom-right
            alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference.
        softcap (`float`, *optional*):
            Softcap for the attention logits, used e.g. in gemma2.
        deterministic (`bool`, *optional*):
            Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
    """
    if not use_top_left_mask:
        causal = is_causal
    else:
        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
        causal = is_causal and query_length != 1

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    use_sliding_windows = (
        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
    )
    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

    if is_flash_attn_greater_or_equal("2.4.1"):
        if deterministic is None:
            deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
        flash_kwargs["deterministic"] = deterministic

    if softcap is not None:
        flash_kwargs["softcap"] = softcap

    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

    # If position_ids is provided and not all examples contain a single sequence: if the tensor is monotonically
    # increasing we probably have one sequence, otherwise the batch is packed. Additionally check that we are in the
    # pre-fill/training stage. Use `flash_attn_varlen_func` to prevent cross-example attention and also allow a
    # padding-free approach.
    # Note: the `torch.diff(...)` condition is last to take advantage of short-circuiting and avoid the CUDA
    # synchronization it incurs during inference (query_length == 1 always).
    elif position_ids is not None and query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all():
        batch_size = query_states.size(0)
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )
        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

    else:
        attn_output = flash_attn_func(
            query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
        )

    return attn_output
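

# Illustrative sketch (not part of the original module): how an attention layer might call
# `_flash_attention_forward` with a partially padded batch. Shapes, dtype and device are
# assumptions for demonstration only; running this requires a CUDA device with the `flash_attn`
# package installed, and the hypothetical helper is never called at import time.
def _example_flash_attention_forward():
    batch_size, seq_len, num_heads, head_dim = 2, 16, 8, 64
    device, dtype = "cuda", torch.float16
    query = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
    key = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
    value = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype)
    # The second example is left-padded with 4 padding tokens, so the unpad/repad (varlen) path is taken.
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long, device=device)
    attention_mask[1, :4] = 0
    attn_output = _flash_attention_forward(
        query, key, value, attention_mask, query_length=seq_len, is_causal=True, dropout=0.0
    )
    # attn_output has shape (batch_size, seq_len, num_heads, head_dim), repadded with zeros at padding positions
    return attn_output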