# mypy: allow-untyped-defs
import logging
from typing import Optional, Tuple

import torch
import torch.nn
import torch.nn.functional as F
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    flash_sdp_enabled,
    math_sdp_enabled,
    mem_efficient_sdp_enabled,
    SDPAParams,
)
from torch.nn.attention import SDPBackend

from .nested_tensor import buffer_from_jagged, NestedTensor, ViewNestedFromBuffer

log = logging.getLogger(__name__)


def _validate_sdpa_input(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    if (
        not isinstance(query, NestedTensor)
        or not isinstance(key, NestedTensor)
        or not isinstance(value, NestedTensor)
    ):
        raise ValueError(
            f"Expected query, key, and value to be nested tensors, "
            f"but got query.is_nested: {query.is_nested}, key.is_nested: {key.is_nested}, "
            f"and value.is_nested: {value.is_nested} instead."
        )
    if query.dtype != key.dtype or query.dtype != value.dtype:
        raise ValueError(
            f"Expected query, key, and value to have the same dtype, "
            f"but got query.dtype: {query.dtype}, key.dtype: {key.dtype}, "
            f"and value.dtype: {value.dtype} instead."
        )
    if query.device != key.device or query.device != value.device:
        raise ValueError(
            f"Expected query, key, and value to have the same device type, "
            f"but got query.device: {query.device}, key.device: {key.device}, "
            f"and value.device: {value.device} instead."
        )
    if query.dim() < 2 or key.dim() < 2 or value.dim() < 2:
        raise ValueError(
            f"Expected query, key, and value to all be at least 2 dimensional, but got query.dim: "
            f"{query.dim()}, key.dim: {key.dim()} and value.dim: {value.dim()} instead."
        )
    if query._ragged_idx != key._ragged_idx or query._ragged_idx != value._ragged_idx:
        raise ValueError(
            f"Expected query, key, and value to all be ragged on the same dimension, but got ragged "
            f"dims {query._ragged_idx}, {key._ragged_idx}, and {value._ragged_idx}, respectively."
        )
    if attn_mask is not None:
        # TODO: Figure out whether masks are actually supported for this layout or not
        raise ValueError("Masks are not yet supported!")
        if attn_mask.dtype != torch.bool and attn_mask.dtype != query.dtype:
            raise ValueError(
                f"Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: "
                f"{attn_mask.dtype}, and query.dtype: {query.dtype} instead."
            )
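
# The _check_* helpers below complement the can_use_flash_attention /
# can_use_efficient_attention checks imported from torch.backends.cuda with
# NestedTensor-specific constraints. Each returns True when its constraint is
# satisfied and, when debug=True, logs why a fused kernel cannot be used.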


def _check_batch_size_nested(params: SDPAParams, debug=False) -> bool:
    # This is expected to be called after check_tensor_shapes ensuring that the
    # size() calls won't error since the inputs are all 4 dimensional
    q_batch_size = params.query.size(0)
    k_batch_size = params.key.size(0)
    v_batch_size = params.value.size(0)

    # num_heads logic for nested input is checked in
    # check_for_seq_len_0_nested_tensor as there is handling there to make sure
    # num_heads is not ragged
    return q_batch_size == k_batch_size and q_batch_size == v_batch_size


def _check_head_dim_size_flash_nested(params: SDPAParams, debug=False) -> bool:
    max_size = 256
    query_size_last = params.query.size(-1)
    key_size_last = params.key.size(-1)
    value_size_last = params.value.size(-1)
    same_head_dim_size = (
        query_size_last == key_size_last and query_size_last == value_size_last
    )
    if not (
        same_head_dim_size
        and (query_size_last % 8 == 0)
        and (query_size_last <= max_size)
    ):
        if debug:
            log.warning(
                "For NestedTensor inputs, Flash attention requires q,k,v to have the same "
                "last dimension and to be a multiple of 8 and less than or equal to 256. "
                "Got Query.size(-1): %d, Key.size(-1): %d, Value.size(-1): %d instead.",
                query_size_last,
                key_size_last,
                value_size_last,
            )
        return False
    return True
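
# For example, a shared head dim of 64, 96, or 128 passes
# _check_head_dim_size_flash_nested (equal across q/k/v, a multiple of 8, and
# <= 256), while 100 (not a multiple of 8) or 264 (> 256) does not.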


def _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
    param: torch.Tensor, param_name: str, debug=False
) -> bool:
    assert isinstance(param, NestedTensor), "param should be a jagged NT"

    if param._ragged_idx == 1:
        # num_head_dims is ragged
        if debug:
            log.warning(
                "Fused kernels do not support ragged num_head_dims, %s has a ragged num_heads.",
                param_name,
            )
        return False

    # This is being called inside sdp with shape [batch, heads, {seq_len}, dim]
    if param._min_seqlen == 0:
        if debug:
            log.warning(
                "Fused kernels do not support seq_len == 0, %s has a seq len of 0.",
                param_name,
            )
        return False

    return True


def _try_broadcast_param_size(q_size, k_size, v_size, param_name, debug=False) -> bool:
    max_size = max(q_size, k_size, v_size)
    if (
        (q_size != max_size and q_size != 1)
        or (k_size != max_size and k_size != 1)
        or (v_size != max_size and v_size != 1)
    ):
        if debug:
            log.warning(
                "Both fused kernels require query, key and value to have broadcastable %s, "
                "got Query %s %d, Key %s %d, Value %s %d instead.",
                param_name,
                param_name,
                q_size,
                param_name,
                k_size,
                param_name,
                v_size,
            )
        return False
    return True
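
# Example for _try_broadcast_param_size: num_heads of (1, 8, 8) for (q, k, v)
# is broadcastable to 8 and returns True, while (2, 8, 8) returns False since
# 2 is neither the max nor 1.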


def _check_for_seq_len_0_nested(params: SDPAParams, debug=False) -> bool:
    # When this function is called we are assured that the nt is dim==4
    q_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.query, "query", debug
        )
        if params.query.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not q_is_safe:
        return False

    k_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.key, "key", debug
        )
        if params.key.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not k_is_safe:
        return False

    v_is_safe = (
        _check_for_seq_len_0_and_consistent_head_dim_nested_helper(
            params.value, "value", debug
        )
        if params.value.is_nested
        else True
    )
    # short circuit if any is unsafe
    if not v_is_safe:
        return False

    # We now know none of the inputs have ragged num_heads, so we can safely
    # access .size(1)
    q_num_heads = params.query.size(1)
    k_num_heads = params.key.size(1)
    v_num_heads = params.value.size(1)
    same_num_heads = q_num_heads == k_num_heads and q_num_heads == v_num_heads

    if not same_num_heads:
        if (
            params.query.requires_grad
            or params.key.requires_grad
            or params.value.requires_grad
        ):
            if debug:
                log.warning(
                    "Both fused kernels do not support training with broadcasted NT inputs."
                )
            return False
        return _try_broadcast_param_size(
            q_num_heads, k_num_heads, v_num_heads, "num heads", debug
        )
    return True


def _can_use_flash_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_head_dim_size_flash_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_efficient_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    constraints = (
        _check_batch_size_nested,
        _check_for_seq_len_0_nested,
    )
    for constraint in constraints:
        if not constraint(params, debug):
            return False
    return True


def _can_use_math_sdpa_jagged(params: SDPAParams, debug=False) -> bool:
    if (
        not params.query.transpose(1, 2).is_contiguous()
        or not params.key.transpose(1, 2).is_contiguous()
        or not params.value.transpose(1, 2).is_contiguous()
    ):
        if debug:
            log.warning(
                "If inputs are nested tensors they must be contiguous after transposing."
            )
        return False
    if params.is_causal:
        if debug:
            log.warning(
                "Nested tensors for query / key are not supported when is_causal=True."
            )
        return False
    return True


def _select_sdp_backend(query, key, value, attn_mask, dropout, is_causal):
    if (
        not flash_sdp_enabled()
        and not mem_efficient_sdp_enabled()
        and not math_sdp_enabled()
    ):
        return SDPBackend.ERROR

    ordering = (
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
        SDPBackend.MATH,
    )

    params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)

    for backend in ordering:
        if backend == SDPBackend.FLASH_ATTENTION:
            if can_use_flash_attention(params) and _can_use_flash_sdpa_jagged(params):
                return SDPBackend.FLASH_ATTENTION
        if backend == SDPBackend.EFFICIENT_ATTENTION:
            if can_use_efficient_attention(params) and _can_use_efficient_sdpa_jagged(
                params
            ):
                return SDPBackend.EFFICIENT_ATTENTION
        if backend == SDPBackend.MATH:
            if math_sdp_enabled() and _can_use_math_sdpa_jagged(params):
                return SDPBackend.MATH

    log.warning("Memory efficient kernel not used because:")
    can_use_efficient_attention(params, debug=True)
    _can_use_efficient_sdpa_jagged(params, debug=True)
    log.warning("Flash attention kernel not used because:")
    can_use_flash_attention(params, debug=True)
    _can_use_flash_sdpa_jagged(params, debug=True)
    log.warning("Math attention kernel not used because:")
    _can_use_math_sdpa_jagged(params, debug=True)
    return SDPBackend.ERROR
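
# The selection above prefers FLASH_ATTENTION, then EFFICIENT_ATTENTION, then
# MATH, subject to the global enablement flags from torch.backends.cuda. As a
# usage note (an assumption about the surrounding API, not something this file
# relies on): callers can typically restrict which backends are considered via
# the torch.nn.attention.sdpa_kernel context manager, which toggles those flags.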


def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
    # This function calculates the metadata needed by the flash-attention and
    # efficient_attention kernels: the cumulative sequence lengths over a batch
    # of sequences and the maximum sequence length.
    # It returns a tuple of (cumulative sequence lengths, maximum sequence
    # length, last element of the cumulative sequence lengths, i.e. the total
    # number of elements Nnz).
    if not isinstance(qkv, NestedTensor):
        raise ValueError("QKV must be nested for flash cumulative_seq_len calculation.")

    if qkv.lengths() is None:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = qkv.offsets().to(dtype=torch.int32, device=qkv.device)
        max_seqlen = qkv._max_seqlen
        n_elem = qkv.values().shape[0]
    else:
        # TODO: Explore performance impact of copying
        cumulative_seqlen = (
            qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device)
        )
        batch_size = qkv.size(0)
        max_seqlen = qkv._max_seqlen
        # TODO: Explore performance impact when compiling
        n_elem = int(cumulative_seqlen[-1].item())
    return cumulative_seqlen, max_seqlen, n_elem
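
# Worked example for _cumulative_and_max_seq_len_nnz: for a jagged NT holding
# sequences of lengths [2, 3, 5] with lengths() == None (no holes in the
# buffer), the offsets are [0, 2, 5, 10], so cumulative_seqlen == [0, 2, 5, 10]
# (int32), max_seqlen == 5, and n_elem == 10.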


def _is_safe_to_get_storage_as_tensor(tensor: torch.Tensor):
    # This function checks if a nested tensor is valid for
    # use with the flash-attention and efficient_attention kernels without
    # needing to call contiguous on the nested tensor input.
    # It checks that the storage offsets' adjacent differences are a constant
    # multiple of the previous tensor in the nested tensor and that the strides
    # are monotonically decreasing. This check is done after calling transpose on
    # the nested tensor, resulting in an NT of shape [bsz, {seq_len}, num_heads, dim].
    # Returns True if the storage can be used directly, i.e. contiguous() does
    # not need to be called on the input.
    assert isinstance(tensor, NestedTensor)
    offsets = tensor.offsets()
    strides = tensor._strides

    n_tensors = offsets.size(0) - 1
    if n_tensors <= 1:
        return True

    # Check initially that the tensor strides are in strictly descending order
    prev_stride = strides[1]
    for stride in strides[2:]:
        if prev_stride <= stride:
            # This would mean that the last stride is greater than the seq_len
            # stride
            return False
        prev_stride = stride

    # Congrats you made it!
    return True


def _view_as_dense(
    tensor: torch.Tensor, Nnz: int, num_heads: int, head_dim: int
) -> torch.Tensor:
    if tensor.is_nested:
        return buffer_from_jagged(tensor)
    return tensor.view(Nnz, num_heads, head_dim)
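
# _view_as_dense packs the transposed [batch, {seq_len}, num_heads, head_dim]
# input into the dense [Nnz, num_heads, head_dim] buffer expected by the fused
# varlen kernels; for a jagged NT this is just its underlying values buffer.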


# TODO: Next iteration should add test cases and check it works
# def _sdpa_nested_preprocessing_with_broadcast(query, key, value):
#     # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
#     # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
#     q_batch_size = query.size(0)
#     k_batch_size = key.size(0)
#     v_batch_size = value.size(0)

#     output_batch_size = max(q_batch_size, k_batch_size, v_batch_size)

#     q_num_heads = query.size(1)
#     k_num_heads = key.size(1)
#     v_num_heads = value.size(1)

#     output_num_heads = max(q_num_heads, k_num_heads, v_num_heads)

#     head_dim_qk = query.size(3)
#     head_dim_v = value.size(3)

#     q_t = query.transpose(1, 2)
#     k_t = key.transpose(1, 2)
#     v_t = value.transpose(1, 2)

#     # Checks in sdp_utils ensure that if {*}_batch_size/{*}_num_heads !=
#     # output_batch_size/num_heads then they are 1
#     q_batch_size_needs_broadcast = q_batch_size != output_batch_size
#     k_batch_size_needs_broadcast = k_batch_size != output_batch_size
#     v_batch_size_needs_broadcast = v_batch_size != output_batch_size

#     # If {*}_batch_size_needs_broadcast, then
#     # (1) max_seqlen_batch_{*} is given by {*}_t.size(1)
#     #     this is because needs_broadcast indicates that the batch_size is 1
#     #     and hence there is only 1 value for seq_len
#     # (2) The cum_seq_lens are given by [0, {*}_t.size(1), 2 * {*}_t.size(1),
#     #     ..., output_batch_size * {*}_t.size(1)]
#     # (3) Nnz_{*} is given by output_batch_size * {*}_t.size(1)

#     if q_batch_size_needs_broadcast or not q_t.is_nested:
#         max_seqlen_batch_q = q_t.size(1)
#         cumulative_sequence_length_q = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_q,
#             max_seqlen_batch_q,
#             device=q_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_q = output_batch_size * max_seqlen_batch_q
#     else:
#         (
#             cumulative_sequence_length_q,
#             max_seqlen_batch_q,
#             Nnz_q,
#         ) = _cumulative_and_max_seq_len_nnz(q_t)

#     if k_batch_size_needs_broadcast and v_batch_size_needs_broadcast:
#         assert k_t.size(1) == v_t.size(1)
#         max_seqlen_batch_kv = k_t.size(1)
#         cumulative_sequence_length_kv = torch.arange(
#             0,
#             (output_batch_size + 1) * max_seqlen_batch_kv,
#             max_seqlen_batch_kv,
#             device=k_t.device,
#             dtype=torch.int32,
#         )
#         Nnz_kv = output_batch_size * max_seqlen_batch_kv
#     else:
#         cumulative_sequence_length_kv, max_seqlen_batch_kv, Nnz_kv = (
#             _cumulative_and_max_seq_len_nnz(v_t)
#             if k_batch_size_needs_broadcast
#             else _cumulative_and_max_seq_len_nnz(k_t)
#         )

#     q_num_heads_needs_broadcast = q_num_heads != output_num_heads
#     k_num_heads_needs_broadcast = k_num_heads != output_num_heads
#     v_num_heads_needs_broadcast = v_num_heads != output_num_heads

#     if not q_t.is_nested:
#         query_buffer_reshaped = q_t.expand(
#             output_batch_size, q_t.size(1), output_num_heads, head_dim_qk
#         )
#         query_buffer_reshaped = query_buffer_reshaped.reshape(
#             Nnz_q, output_num_heads, head_dim_qk
#         )
#     else:
#         if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
#             q_t = q_t.contiguous()
#         # If we are broadcasting then Nnz_q will be the output_batch_size since
#         # seq_len is 1
#         effective_batch_size_q = (
#             output_batch_size if q_batch_size_needs_broadcast else Nnz_q
#         )
#         query_buffer_reshaped = _view_as_dense(
#             q_t, effective_batch_size_q, output_num_heads, head_dim_qk
#         )

#     # If the physical layout of the NestedTensor's storage
#     # is not: batch, {seq_len}, num_heads, head_dim then we need
#     # to call contiguous
#     if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
#         k_t = k_t.contiguous()
#     if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
#         v_t = v_t.contiguous()

#     effective_batch_size_k = (
#         output_batch_size if k_batch_size_needs_broadcast else Nnz_kv
#     )
#     key_buffer_reshaped = _view_as_dense(
#         k_t, effective_batch_size_k, output_num_heads, head_dim_qk
#     )

#     effective_batch_size_v = (
#         output_batch_size if v_batch_size_needs_broadcast else Nnz_kv
#     )
#     value_buffer_reshaped = _view_as_dense(
#         v_t, effective_batch_size_v, output_num_heads, head_dim_v
#     )

#     if not q_batch_size_needs_broadcast:
#         output_shape = q_t._size
#         if head_dim_v != head_dim_qk:
#             output_shape[-1] = head_dim_v
#         if q_num_heads_needs_broadcast:
#             output_shape[1] = output_num_heads
#     else:
#         output_shape = torch.empty(3, dtype=torch.int64, device=torch.device("cpu"))
#         output_shape[0] = q_t.size(1)
#         output_shape[1] = output_num_heads
#         output_shape[2] = head_dim_v

#     return (
#         query_buffer_reshaped,
#         key_buffer_reshaped,
#         value_buffer_reshaped,
#         cumulative_sequence_length_q,
#         cumulative_sequence_length_kv,
#         max_seqlen_batch_q,
#         max_seqlen_batch_kv,
#         output_shape,
#     )


def _sdpa_nested_preprocessing(query, key, value):
    # Query (Batch x Num_heads x {Q_seq_len} x Dim_per_head)
    # Key (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    # Value (Batch x Num_heads x {KV_seq_len} x Dim_per_head)
    q_batch_size = query.size(0)
    k_batch_size = key.size(0)
    v_batch_size = value.size(0)

    q_num_heads = query.size(1)
    k_num_heads = key.size(1)
    v_num_heads = value.size(1)

    if not (q_batch_size == k_batch_size and q_batch_size == v_batch_size) or not (
        q_num_heads == k_num_heads and k_num_heads == v_num_heads
    ):
        raise RuntimeError(
            "This path is currently not implemented for jagged layout NT."
        )
        # return _sdpa_nested_preprocessing_with_broadcast(query, key, value)

    num_heads = query.size(1)
    head_dim_qk = query.size(3)
    head_dim_v = value.size(3)
    q_t = query.transpose(1, 2)
    k_t = key.transpose(1, 2)
    v_t = value.transpose(1, 2)

    (
        cumulative_sequence_length_q,
        max_seqlen_batch_q,
        Nnz_q,
    ) = _cumulative_and_max_seq_len_nnz(q_t)
    (
        cumulative_sequence_length_kv,
        max_seqlen_batch_kv,
        Nnz_kv,
    ) = _cumulative_and_max_seq_len_nnz(k_t)

    # [TODO] K and V have to have the same Nnz, should probably torch_check
    # assume in order to not iterate over v

    # If the physical layout of the NestedTensor's storage
    # is not: batch, {seq_len}, num_heads, head_dim then we need
    # to call contiguous
    if not q_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(q_t):
        q_t = q_t.contiguous()
    if not k_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(k_t):
        k_t = k_t.contiguous()
    if not v_t.is_contiguous() and not _is_safe_to_get_storage_as_tensor(v_t):
        v_t = v_t.contiguous()

    query_buffer_reshaped = _view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk)
    key_buffer_reshaped = _view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk)
    value_buffer_reshaped = _view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v)

    output_nt_info = {
        "offsets": q_t.offsets(),
        "_max_seqlen": q_t._max_seqlen,
        "_min_seqlen": q_t._min_seqlen,
    }

    return (
        query_buffer_reshaped,
        key_buffer_reshaped,
        value_buffer_reshaped,
        cumulative_sequence_length_q,
        cumulative_sequence_length_kv,
        max_seqlen_batch_q,
        max_seqlen_batch_kv,
        output_nt_info,
    )
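
# The buffers returned by _sdpa_nested_preprocessing, together with the
# cumulative sequence lengths and max sequence lengths, form the packed
# (Nnz, num_heads, head_dim) inputs consumed by _flash_attention_forward /
# _efficient_attention_forward below; output_nt_info carries the offsets and
# min/max seqlen needed to rewrap the kernel output as a jagged NT.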


def _pad_last_dim(
    tensor: torch.Tensor, alignment_size: int, slice: bool
) -> torch.Tensor:
    # FlashAttentionV2 requires that the head dimension be a multiple of 8.
    # This was previously done within the kernel; however, that can cause the
    # kernel to alias query, key, and value, so instead we pad the head
    # dimension to a multiple of 8 in the composite region.
    last_dim_size = tensor.size(-1)
    if last_dim_size % alignment_size == 0:
        return tensor
    pad_count = alignment_size - (last_dim_size % alignment_size)
    tensor = torch.nn.functional.pad(tensor, [0, pad_count])
    if slice:
        return tensor[..., 0:last_dim_size]
    return tensor
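
# Example for _pad_last_dim: a head dim of 13 with alignment_size=8 is padded
# to 16; with slice=True the result is sliced back to [..., 0:13], so only the
# underlying storage width changes.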


# TODO: coalesce with torch/nn/utils/attention.py
def _calculate_scale(query, scale):
    # TODO: Investigate why math.sqrt() isn't properly handled by Dynamo?
    softmax_scale = scale if scale is not None else torch.sym_sqrt(1.0 / query.size(-1))
    return softmax_scale
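
# The default softmax scale above is sqrt(1 / head_dim) == 1 / sqrt(E), where
# E = query.size(-1), matching the dense scaled_dot_product_attention default.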


def _post_process_flash_output(out: torch.Tensor, og_size):
    if not out.is_nested and out.size(-1) != og_size:
        out = out[..., 0:og_size]
    return out


def jagged_scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
):
    _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
    # for mypy, ugh
    assert (
        isinstance(query, NestedTensor)
        and isinstance(key, NestedTensor)
        and isinstance(value, NestedTensor)
    )

    # Special path for non-ragged sequence length (e.g. for SAM where we have a ragged
    # second batch dim instead). For this case, we can just send the dense buffers through
    # vanilla SDPA.
    if query.dim() > 3 and key.dim() > 3 and value.dim() > 3 and query._ragged_idx == 1:
        from torch.nested._internal.ops import extract_kwargs

        output = F.scaled_dot_product_attention(
            query._values,
            key._values,
            value._values,
            attn_mask=(
                attn_mask._values if isinstance(attn_mask, NestedTensor) else attn_mask
            ),
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )

        return NestedTensor(output, **extract_kwargs(query))

    compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad

    backend_choice = _select_sdp_backend(
        query, key, value, attn_mask, dropout_p, is_causal
    )

    if backend_choice == SDPBackend.FLASH_ATTENTION:
        og_size = query.size(-1)
        query_padded = _pad_last_dim(query, 8, False)
        key_padded = _pad_last_dim(key, 8, False)
        value_padded = _pad_last_dim(value, 8, False)
        # We need to calculate the scale based off the OG head dim size
        og_scale = _calculate_scale(query, scale)
        (
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query_padded, key_padded, value_padded)

        (
            attention,
            logsumexp,
            philox_seed,
            philox_offset,
            debug_attn_mask,
        ) = torch.ops.aten._flash_attention_forward(
            query_buffer_reshaped,
            key_buffer_reshaped,
            value_buffer_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            is_causal,
            False,
            scale=og_scale,
        )

        # Reshape output to convert nnz to batch_size and seq_len
        attention = ViewNestedFromBuffer.apply(
            attention,  # output from flash_attn is [total_q, num_heads, head_size_og]
            output_nt_info["offsets"],
        ).transpose(1, 2)
        return _post_process_flash_output(attention, og_size)
    elif backend_choice == SDPBackend.EFFICIENT_ATTENTION:
        (
            query_reshaped,
            key_reshaped,
            value_reshaped,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            output_nt_info,
        ) = _sdpa_nested_preprocessing(query, key, value)
        (
            attention,
            log_sumexp,
            seed,
            offset,
            max_seqlen_q,
            max_seqlen_batch_kv,
        ) = torch.ops.aten._efficient_attention_forward(
            query_reshaped.unsqueeze(0),
            key_reshaped.unsqueeze(0),
            value_reshaped.unsqueeze(0),
            None,
            cumulative_sequence_length_q,
            cumulative_sequence_length_kv,
            max_seqlen_batch_q,
            max_seqlen_batch_kv,
            dropout_p,
            int(is_causal),
            compute_logsumexp,
            scale=scale,
        )

        # Reshape output to convert nnz to batch_size and seq_len
        return ViewNestedFromBuffer.apply(
            attention.squeeze(0), output_nt_info["offsets"]
        ).transpose(1, 2)
    elif backend_choice == SDPBackend.MATH:
        # save the offsets and shape of the inputs, so we can reshape the final output
        # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D', j1] = [B, D1, j0, j1]
        # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2]
        offsets = query.offsets()
        d1 = query._size[1]
        d2 = value._size[-1]

        # convert jagged layout Nested Tensor to strided layout Nested Tensor,
        # which supports the math implementation of SDPA
        def get_strided_layout_nested_tensor(jagged_layout_nt):
            lengths = jagged_layout_nt._offsets[1:] - jagged_layout_nt._offsets[:-1]
            transpose = torch.transpose(jagged_layout_nt, 1, 2)
            tensor_list = buffer_from_jagged(transpose).split(list(lengths), dim=0)
            strided_nt = torch.nested.as_nested_tensor(list(tensor_list))
            strided_nt = strided_nt.transpose(1, 2).contiguous()
            return strided_nt

        query = get_strided_layout_nested_tensor(query)
        key = get_strided_layout_nested_tensor(key)
        value = get_strided_layout_nested_tensor(value)

        attn_out = torch._scaled_dot_product_attention_math(
            query, key, value, attn_mask, dropout_p, is_causal, scale=scale
        )[0]

        # convert strided layout Nested Tensor back to jagged layout Nested Tensor
        attn_out = attn_out.transpose(1, 2).contiguous().values()
        attn_out = attn_out.view(-1, d1, d2)
        attn_out = ViewNestedFromBuffer.apply(attn_out, offsets)
        attn_out = attn_out.transpose(1, 2)

        return attn_out
    else:
        raise RuntimeError(
            "No viable backend for scaled_dot_product_attention was found."
        )
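

# Usage sketch (not part of the original module): jagged_scaled_dot_product_attention
# is reached through NestedTensor op dispatch when F.scaled_dot_product_attention is
# called on jagged-layout nested tensors. A minimal sketch, assuming a build where the
# jagged layout and the CUDA fused kernels are available:
#
#   import torch
#   import torch.nn.functional as F
#
#   lengths = (3, 5, 7)          # ragged sequence lengths
#   num_heads, head_dim = 8, 64
#
#   def make_nt():
#       # (B, j1, num_heads, head_dim) jagged NT, then move heads before the ragged dim
#       return torch.nested.nested_tensor(
#           [torch.randn(n, num_heads, head_dim, device="cuda", dtype=torch.float16)
#            for n in lengths],
#           layout=torch.jagged,
#       ).transpose(1, 2)
#
#   q, k, v = make_nt(), make_nt(), make_nt()
#   out = F.scaled_dot_product_attention(q, k, v)   # routes to this module for jagged NTs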