# mypy: allow-untyped-defs
import logging
from typing import Optional, Tuple

import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    flash_sdp_enabled,
    math_sdp_enabled,
    mem_efficient_sdp_enabled,
    SDPAParams,
)
from torch.nn.attention import SDPBackend

from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer

log = logging.getLogger(__name__)


def _validate_sdpa_input(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    if (
        not isinstance(query, NestedTensor)
        or not isinstance(key, NestedTensor)
        or not isinstance(value, NestedTensor)
    ):
        raise ValueError(
            f"Expected query, key, and value to be nested tensors, "
            f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
            f"and value.is_nested: {value.is_nested} instead."
        )
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )
    if query.device != key.device or query.device != value.device:
        raise ValueError(
            f"Expected query, key, and value to have the same device type, "
            f"but got query.device: {query.device}, key.device: {key.device}, "
            f"and value.device: {value.device} instead."
        )
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError(
            f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
        )
    if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
        raise ValueError(
            f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
            f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
        )
    if attn_mask is not None:
        # TODO: Figure out whether masks are actually supported for this layout or not
        raise ValueError("Masks are not yet supported!")
        # NOTE: the dtype check below is currently unreachable because of the raise
        # above; it is kept in place for when mask support lands.
        if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
            raise ValueError(
                f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
                f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
            )


def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
    # This is expected to be called after check_tensor_shapes ensuring that the
    # size() calls won't error since the inputs are all 4 dimensional
    q_batch_size = params.query.size(0)
    k_batch_size = params.key.size(0)
    v_batch_size = params.value.size(0)

    # num_heads logic for nested input is checked in
    # check_for_seq_len_0_nested_tensor as there is handling there to make sure
    # num_heads is not ragged
    return q_batch_size == k_batch_size and q_batch_size == v_batch_size


def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
    max_size = 256
    query_size_last = params.query.size(-1)
    key_size_last = params.key.size(-1)
    value_size_last = params.value.size(-1)
    same_head_dim_size = (
        query_size_last == key_size_last and query_size_last == value_size_last
    )
    if not (
        same_head_dim_size
        and (query_size_last % 8 == 0)
        and (query_size_last <= max_size)
    ):
        if debug:
            log.warning(
                "For NestedTensor inputs, Flash attention requires q,k,v to have the same "
                "last dimension and to be a multiple of 8 and less than or equal to 256. "
                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
                query_size_last,
                key_size_last,
                value_size_last,
            )
        return False
    return True
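
# Editorial note (illustrative, not part of the original module): under the rule
# above, jagged inputs with head dims such as 64, 96, or 128 (equal across q/k/v,
# a multiple of 8, and at most 256) pass this check, while e.g. 60 (not a multiple
# of 8) or 320 (> 256) cause the flash path to be rejected and backend selection
# to move on to the next kernel.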


def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
    param: torch.Tensor, param_name: str, debug=False
) -> bool:
    assert isinstance(param, NestedTensor), "param should be a jagged NT"

    if param._ragged_idx == 1:
        # num_head_dims is ragged
        if debug:
            log.warning(
                "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
                param_name,
            )
        return False

    # This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
    if param._min_seqlen == 0:
        if debug:
            log.warning(
                "Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
                param_name,
            )
        return False

    return True


def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
    max_size = max(q_size, k_size, v_size)
    if (
        (q_size != max_size and q_size != 1)
        or (k_size != max_size and k_size != 1)
        or (v_size != max_size and v_size != 1)
    ):
        if debug:
            log.warning(
                "Both fused kernels require query, key and value to have broadcastable %s, "
                "got Query %s %d, Key %s %d, Value %s %d instead.",
                param_name,
                param_name,
                q_size,
                param_name,
                k_size,
                param_name,
                v_size,
            )
        return False
    return True


def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
    # When this function is called we are assured that the nt is dim==4
    q_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.query, "query", debug
        )
        if params.query.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not q_is_safe:
        return False

    k_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.key, "key", debug
        )
        if params.key.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not k_is_safe:
        return False

    v_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.value, "value", debug
        )
        if params.value.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not v_is_safe:
        return False

    # We now know none of the inputs have ragged num_heads, so we can safely
    # access .size(1)
    q_num_heads = params.query.size(1)
    k_num_heads = params.key.size(1)
    v_num_heads = params.value.size(1)
    same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads

    if not same_num_heads:
        if (
            params.query.requires_grad
            or params.key.requires_grad
            or params.value.requires_grad
        ):
            if debug:
                log.warning(
                    "Both fused kernels do not support training with broadcasted NT inputs."
                )
            return False
        return _try_broadcast_param_size(
            q_num_heads, k_num_heads, v_num_heads, "num heads", debug
        )
    return True


def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_head_dim_size_flash_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    if (
        not params.query.transpose(1, 2).is_contiguous()
        or not params.key.transpose(1, 2).is_contiguous()
        or not params.value.transpose(1, 2).is_contiguous()
    ):
        if debug:
            log.warning(
                "If inputs are nested tensors they must be contiguous after transposing."
            )
        return False
    if params.is_causal:
        if debug:
            log.warning(
                "Nested tensors for query / key are not supported when is_causal=True."
            )
        return False
    return True


def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
    if (
        not flash_sdp_enabled()
        and not mem_efficient_sdp_enabled()
        and not math_sdp_enabled()
    ):
        return SDPBackend.ERROR

    ordering = (
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    )

    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)

    for backend in ordering:
        if backend == SDPBackend.FLASH_ATTENTION:
            if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
                return SDPBackend.FLASH_ATTENTION
        if backend == SDPBackend.EFFICIENT_ATTENTION:
            if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
                params
            ):
                return SDPBackend.EFFICIENT_ATTENTION
        if backend == SDPBackend.MATH:
            if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
                return SDPBackend.MATH

    log.warning("Memory efficient kernel not used because:")
    can_use_efficient_attention(params, debug=True)
    _can_use_efficient_sdpa_jagged(params, debug=True)
    log.warning("Flash attention kernel not used because:")
    can_use_flash_attention(params, debug=True)
    _can_use_flash_sdpa_jagged(params, debug=True)
    log.warning("Math attention kernel not used because:")
    _can_use_math_sdpa_jagged(params, debug=True)
    return SDPBackend.ERROR
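
# Editorial sketch (not part of the original module): backend availability is
# governed by the global flags queried above. For instance, restricting SDPA to
# the math backend (so _can_use_math_sdpa_jagged becomes the deciding check) can
# be done with the torch.backends.cuda context manager:
#
#     with torch.backends.cuda.sdp_kernel(
#         enable_flash=False, enable_mem_efficient=False, enable_math=True
#     ):
#         out = F.scaled_dot_product_attention(q_nt, k_nt, v_nt)
#
# Here q_nt/k_nt/v_nt are assumed to be jagged-layout NestedTensors shaped
# [B, n_heads, {seq_len}, head_dim].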


def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
    # This function calculates the metadata needed by the flash-attention and
    # efficient_attention kernels: the cumulative sequence lengths over the batch
    # of sequences, the maximum sequence length, and the total number of elements
    # (i.e. the last entry of the cumulative sequence lengths).
    if not isinstance(qkv, NestedTensor):
        raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")

    if qkv.lengths() is None:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
        max_seqlen = qkv._max_seqlen
        n_elem = qkv.values().shape[0]
    else:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = (
            qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
        )
        batch_size = qkv.size(0)
        max_seqlen = qkv._max_seqlen
        # TODO: Explore performance impact when compiling
        n_elem = int(cumulative_seqlen[-1].item())
    return cumulative_seqlen, max_seqlen, n_elem
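
# Editorial sketch (assumed values, not part of the original module): for a
# jagged NT holding sequences of lengths [2, 3, 1], offsets() is
# tensor([0, 2, 5, 6]), so this helper returns
#     cumulative_seqlen = tensor([0, 2, 5, 6], dtype=torch.int32)
#     max_seqlen        = 3
#     n_elem            = 6   # values().shape[0], i.e. total tokens in the batch
# which is exactly the (cu_seqlens, max_seqlen, Nnz) metadata consumed by the
# varlen flash/efficient kernels.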


def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
    # This function checks whether a nested tensor can be used with the
    # flash-attention and efficient_attention kernels without needing to call
    # contiguous() on the nested tensor input.
    # It checks that the strides of the underlying storage are monotonically
    # decreasing. This check is done after calling transpose on the nested
    # tensor, resulting in an NT of shape [bsz, {seq_len}, num_heads, dim].
    # Returns True if the storage can be used as-is, False if contiguous()
    # needs to be called first.
    assert isinstance(tensor, NestedTensor)
    offsets = tensor.offsets()
    strides = tensor._strides

    n_tensors = offsets.size(0) - 1
    if n_tensors <= 1:
        return True

    # Check initially that the tensor strides are in strictly descending order
    prev_stride = strides[1]
    for stride in strides[2:]:
        if prev_stride <= stride:
            # This would mean that the last stride is greater than the seq_len
            # stride
            return False
        prev_stride = stride

    # Congrats you made it!
    return True


def _view_as_dense(
    tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
) -> torch.Tensor:
    if tensor.is_nested:
        return buffer_from_jagged(tensor)
    return tensor.view(Nnz, num_heads, head_dim)


# TODO: Next iteration should add test cases and check it works
# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
#     # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
#     # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     q_batch_size = query.size(0)
#     k_batch_size = key.size(0)
#     v_batch_size = value.size(0)
#     output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)
#
#     q_num_heads = query.size(1)
#     k_num_heads = key.size(1)
#     v_num_heads = value.size(1)
#     output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)
#
#     head_dim_qk = query.size(3)
#     head_dim_v = value.size(3)
#
#     q_t = query.transpose(1, 2)
#     k_t = key.transpose(1, 2)
#     v_t = value.transpose(1, 2)
#
#     # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
#     # output_batch_size/num_heads then they are 1
#     q_batch_size_needs_broadcast = q_batch_size != output_batch_size
#     k_batch_size_needs_broadcast = k_batch_size != output_batch_size
#     v_batch_size_needs_broadcast = v_batch_size != output_batch_size
#
#     # If {*}_batch_size_needs_broadcast, then
#     # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
#     #     this is because needs_broadcast indicates that the batch_size is 1
#     #     and hence there is only 1 value for seq_len
#     # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
#     #     ..., output_batch_size * {*}_t.size(1)]
#     # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)
#     if q_batch_size_needs_broadcast or not q_t.is_nested:
#         max_seqlen_batch_q = q_t.size(1)
#         cumulative_sequence_length_q = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_q,
#             max_seqlen_batch_q,
#             device=q_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_q = output_batch_size * max_seqlen_batch_q
#     else:
#         (
#             cumulative_sequence_length_q,
#             max_seqlen_batch_q,
#             Nnz_q,
#         ) = _cumulative_and_max_seq_len_nnz(q_t)
#
#     if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
#         assert k_t.size(1) == v_t.size(1)
#         max_seqlen_batch_kv = k_t.size(1)
#         cumulative_sequence_length_kv = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_kv,
#             max_seqlen_batch_kv,
#             device=k_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_kv = output_batch_size * max_seqlen_batch_kv
#     else:
#         cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
#             _cumulative_and_max_seq_len_nnz(v_t)
#             if k_batch_size_needs_broadcast
#             else _cumulative_and_max_seq_len_nnz(k_t)
#         )
#
#     q_num_heads_needs_broadcast = q_num_heads != output_num_heads
#     k_num_heads_needs_broadcast = k_num_heads != output_num_heads
#     v_num_heads_needs_broadcast = v_num_heads != output_num_heads
#
#     if not q_t.is_nested:
#         query_buffer_reshaped = q_t.expand(
#             output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
#         )
#         query_buffer_reshaped = query_buffer_reshaped.reshape(
#             Nnz_q, output_num_heads, head_dim_qk
#         )
#     else:
#         if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
#             q_t = q_t.contiguous()
#         # If we are broadcasting then Nnz_q will be the output_batch_size since
#         # seq_len is 1
#         effective_batch_size_q = (
#             output_batch_size if q_batch_size_needs_broadcast else Nnz_q
#         )
#         query_buffer_reshaped = _view_as_dense(
#             q_t, effective_batch_size_q, output_num_heads, head_dim_qk
#         )
#
#     # If the physical layout of the NestedTensor's storage
#     # is not: batch, {seq_len}, num_heads, head_dim then we need
#     # to call contiguous
#     if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
#         k_t = k_t.contiguous()
#     if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
#         v_t = v_t.contiguous()
#
#     effective_batch_size_k = (
#         output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
#     )
#     key_buffer_reshaped = _view_as_dense(
#         k_t, effective_batch_size_k, output_num_heads, head_dim_qk
#     )
#
#     effective_batch_size_v = (
#         output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
#     )
#     value_buffer_reshaped = _view_as_dense(
#         v_t, effective_batch_size_v, output_num_heads, head_dim_v
#     )
#
#     if not q_batch_size_needs_broadcast:
#         output_shape = q_t._size
#         if head_dim_v != head_dim_qk:
#             output_shape[-1] = head_dim_v
#         if q_num_heads_needs_broadcast:
#             output_shape[1] = output_num_heads
#     else:
#         output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
#         output_shape[0] = q_t.size(1)
#         output_shape[1] = output_num_heads
#         output_shape[2] = head_dim_v
#
#     return (
#         query_buffer_reshaped,
#         key_buffer_reshaped,
#         value_buffer_reshaped,
#         cumulative_sequence_length_q,
#         cumulative_sequence_length_kv,
#         max_seqlen_batch_q,
#         max_seqlen_batch_kv,
#         output_shape,
#     )


def _sdpa_nested_preprocessing(query, key, value):
    # Query (Batch x Num_heads x {Q_seq_len}  x Dim_per_head)
    # Key   (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    q_batch_size = query.size(0)
    k_batch_size = key.size(0)
    v_batch_size = value.size(0)

    q_num_heads = query.size(1)
    k_num_heads = key.size(1)
    v_num_heads = value.size(1)

    if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
        q_num_heads == k_num_heads and k_num_heads == v_num_heads
    ):
        raise RuntimeError(
            "This path is currently not implemented for jagged layout NT."
        )
        # return _sdpa_nested_preprocessing_with_broadcast(query, key, value)

    num_heads = query.size(1)
    head_dim_qk = query.size(3)
    head_dim_v = value.size(3)
    q_t = query.transpose(1, 2)
    k_t = key.transpose(1, 2)
    v_t = value.transpose(1, 2)

    (
        cumulative_sequence_length_q,
        max_seqlen_batch_q,
        Nnz_q,
    ) = _cumulative_and_max_seq_len_nnz(q_t)
    (
        cumulative_sequence_length_kv,
        max_seqlen_batch_kv,
        Nnz_kv,
    ) = _cumulative_and_max_seq_len_nnz(k_t)

    # TODO: K and V must have the same Nnz; this should probably be verified
    # with a torch_check. For now we assume it so we don't also iterate over v.

    # If the physical layout of the NestedTensor's storage
    # is not: batch, {seq_len}, num_heads, head_dim then we need
    # to call contiguous
    if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
        q_t = q_t.contiguous()
    if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
        k_t = k_t.contiguous()
    if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
        v_t = v_t.contiguous()

    query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
    key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
    value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)

    output_nt_info = {
        "offsets": q_t.offsets(),
        "_max_seqlen": q_t._max_seqlen,
        "_min_seqlen": q_t._min_seqlen,
    }

    return (
        query_buffer_reshaped,
        key_buffer_reshaped,
        value_buffer_reshaped,
        cumulative_sequence_length_q,
        cumulative_sequence_length_kv,
        max_seqlen_batch_q,
        max_seqlen_batch_kv,
        output_nt_info,
    )
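
# Editorial sketch (assumed shapes, not part of the original module): given q/k/v
# as jagged NTs of shape [B, n_heads, {seq_len}, head_dim] with B=2, n_heads=4,
# head_dim=64 and sequence lengths [3, 5], preprocessing returns dense buffers of
# shape [Nnz, n_heads, head_dim] = [8, 4, 64] plus
# cu_seqlens = tensor([0, 3, 8], dtype=torch.int32) and max_seqlen = 5,
# i.e. the packed "varlen" form expected by the fused kernels.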


def _pad_last_dim(
    tensor: torch.Tensor, alignment_size: int, slice: bool
) -> torch.Tensor:
    # FlashAttentionV2 requires the head dimension to be a multiple of 8.
    # This was previously done inside the kernel; however, that could cause the
    # kernel to alias query, key, and value. So instead we pad the head
    # dimension up to a multiple of 8 here, in the composite region.
    last_dim_size = tensor.size(-1)
    if last_dim_size % alignment_size == 0:
        return tensor
    pad_count = alignment_size - (last_dim_size % alignment_size)
    tensor = torch.nn.functional.pad(tensor, [0, pad_count])
    if slice:
        return tensor[..., 0:last_dim_size]
    return tensor
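
# Editorial sketch (not part of the original module): with alignment_size=8, a
# tensor whose last dim is 60 gets 4 zero columns appended (60 -> 64); with
# slice=True the result is then sliced back to the original 60 columns (a view
# into the padded storage). A last dim that is already a multiple of 8 is
# returned unchanged.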


# TODO: coalesce with torch/nn/utils/attention.py
def _calculate_scale(query, scale):
    # TODO: Investigate why math.sqrt() isn't properly handled by Dynamo?
    softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
    return softmax_scale


def _post_process_flash_output(out: torch.Tensor, og_size):
    if not out.is_nested and out.size(-1) != og_size:
        out = out[..., 0:og_size]
    return out


def jagged_scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
    # for mypy, ugh
    assert (
        isinstance(query, NestedTensor)
        and isinstance(key, NestedTensor)
        and isinstance(value, NestedTensor)
    )

    # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
    # second batch dim instead). For this case, we can just send the dense buffers through
    # vanilla SDPA.
    if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
        from torch.nested._internal.ops import extract_kwargs

        output = F.scaled_dot_product_attention(
            query._values,
            key._values,
            value._values,
            attn_mask=(
                attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
            ),
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )

        return NestedTensor(output, **extract_kwargs(query))

    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad

    backend_choice = _select_sdp_backend(
        query, key, value, attn_mask, dropout_p, is_causal
    )

    if backend_choice == SDPBackend.FLASH_ATTENTION:
        og_size = query.size(-1)
        query_padded = _pad_last_dim(query, 8, False)
        key_padded = _pad_last_dim(key, 8, False)
        value_padded = _pad_last_dim(value, 8, False)
        # We need to calculate the scale based off the OG head dim size
        og_scale = _calculate_scale(query, scale)
        (
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)

        (
            attention,
            logsumexp,
            philox_seed,
            philox_offset,
            debug_attn_mask,
        ) = torch.ops.aten._flash_attention_forward(
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            is_causal,
            False,
            scale=og_scale,
        )

        # Reshape output to convert nnz to batch_size and seq_len
        attention = ViewNestedFromBuffer.apply(
            attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
            output_nt_info["offsets"],
        ).transpose(1, 2)
        return _post_process_flash_output(attention, og_size)
    elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
        (
            query_reshaped,
            key_reshaped,
            value_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query, key, value)
        (
            attention,
            log_sumexp,
            seed,
            offset,
            max_seqlen_q,
            max_seqlen_batch_kv,
        ) = torch.ops.aten._efficient_attention_forward(
            query_reshaped.unsqueeze(0),
            key_reshaped.unsqueeze(0),
            value_reshaped.unsqueeze(0),
            None,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            int(is_causal),
            compute_logsumexp,
            scale=scale,
        )

        # Reshape output to convert nnz to batch_size and seq_len
        return ViewNestedFromBuffer.apply(
            attention.squeeze(0), output_nt_info["offsets"]
        ).transpose(1, 2)
    elif backend_choice == SDPBackend.MATH:
        # save the offsets and shape of the inputs, so we can reshape the final output
        # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
        # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
        offsets = query.offsets()
        d1 = query._size[1]
        d2 = value._size[-1]

        # convert jagged layout Nested Tensor to strided layout Nested Tensor,
        # which supports the math implementation of SDPA
        def get_strided_layout_nested_tensor(jagged_layout_nt):
            lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
            transpose = torch.transpose(jagged_layout_nt, 1, 2)
            tensor_list = buffer_from_jagged(transpose).split(list(lengths), dim=0)
            strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
            strided_nt = strided_nt.transpose(1, 2).contiguous()
            return strided_nt

        query = get_strided_layout_nested_tensor(query)
        key = get_strided_layout_nested_tensor(key)
        value = get_strided_layout_nested_tensor(value)

        attn_out = torch._scaled_dot_product_attention_math(
            query, key, value, attn_mask, dropout_p, is_causal, scale=scale
        )[0]

        # convert strided layout Nested Tensor back to jagged layout Nested Tensor
        attn_out = attn_out.transpose(1, 2).contiguous().values()
        attn_out = attn_out.view(-1, d1, d2)
        attn_out = ViewNestedFromBuffer.apply(attn_out, offsets)
        attn_out = attn_out.transpose(1, 2)

        return attn_out
    else:
        raise RuntimeError(
            "No viable backend for scaled_dot_product_attention was found."
        )
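
# Editorial usage sketch (not part of the original module): jagged-layout
# NestedTensors dispatch F.scaled_dot_product_attention to the function above.
# A minimal end-to-end call might look like the following (device, dtype, and
# shapes are assumptions for illustration):
#
#     import torch
#     import torch.nn.functional as F
#
#     n_heads, head_dim = 4, 64
#     seqs = [
#         torch.randn(n, n_heads, head_dim, device="cuda", dtype=torch.float16)
#         for n in (3, 5)
#     ]
#     # Build a [B, {seq_len}, n_heads, head_dim] jagged NT, then move heads
#     # in front of the ragged seq_len dim as SDPA expects.
#     nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
#     q = k = v = nt.transpose(1, 2)  # [B, n_heads, {seq_len}, head_dim]
#     out = F.scaled_dot_product_attention(q, k, v)  # routed to the jagged impl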