diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index c3c769d5ead..6283f079833 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -5,8 +5,6 @@ cmake_policy(SET CMP0169 OLD)
 
 include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
 
-set(BUILD_FA3, OFF)
-
 find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)
 
 enable_language(CUDA)
@@ -80,7 +78,6 @@ include_directories(
   ${repo-cutlass_SOURCE_DIR}/examples/common
   ${repo-flashinfer_SOURCE_DIR}/include
   ${repo-flashinfer_SOURCE_DIR}/csrc
-  ${repo-flash-attention_SOURCE_DIR}/hopper
 )
 
 set(CMAKE_CXX_STANDARD 17)
@@ -115,6 +112,9 @@ option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
 option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
 option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
+option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
+
+
 
 if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-gencode=arch=compute_100,code=sm_100"
@@ -127,7 +127,7 @@ else()
 endif()
 
 if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
-  set(BUILD_FA3 ON)
+  set(SGL_KERNEL_ENABLE_FA3 ON)
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-gencode=arch=compute_90a,code=sm_90a"
   )
@@ -187,11 +187,33 @@ set(SOURCES
   "${repo-flashinfer_SOURCE_DIR}/csrc/norm.cu"
   "${repo-flashinfer_SOURCE_DIR}/csrc/renorm.cu"
   "${repo-flashinfer_SOURCE_DIR}/csrc/sampling.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_causal_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_bf16_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_causal_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src/flash_fwd_sparse_hdim128_fp16_sm80.cu"
+  "${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/flash_sparse_api.cpp"
 )
 
+Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
+
+target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
+target_include_directories(common_ops PRIVATE
+  ${TORCH_INCLUDE_DIRS}
+  ${repo-flash-attention_SOURCE_DIR}/csrc/flash_attn/src)
+target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
+
+target_compile_definitions(common_ops PRIVATE
+  FLASHATTENTION_DISABLE_BACKWARD
+  FLASHATTENTION_DISABLE_DROPOUT
+  FLASHATTENTION_DISABLE_UNEVEN_K
+  )
+
+install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
+
+# ============================ Optional Install ============================= #
 # set flash-attention sources file
 # BF16 source files
-if (BUILD_FA3)
+if (SGL_KERNEL_ENABLE_FA3)
   set(SGL_FLASH_KERNEL_CUDA_FLAGS
     "-DNDEBUG"
     "-DOPERATOR_NAMESPACE=sgl-kernel"
@@ -246,7 +268,9 @@ if (BUILD_FA3)
   Python_add_library(flash_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${FLASH_SOURCES})
 
   target_compile_options(flash_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_FLASH_KERNEL_CUDA_FLAGS}>)
-  target_include_directories(flash_ops PRIVATE ${TORCH_INCLUDE_DIRS})
+  target_include_directories(flash_ops PRIVATE
+    ${TORCH_INCLUDE_DIRS}
+    ${repo-flash-attention_SOURCE_DIR}/hopper)
   target_link_libraries(flash_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda)
 
   install(TARGETS flash_ops LIBRARY DESTINATION "sgl_kernel")
@@ -260,14 +284,6 @@ if (BUILD_FA3)
   )
 endif()
 
-Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})
-
-target_compile_options(common_ops PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:${SGL_KERNEL_CUDA_FLAGS}>)
-target_include_directories(common_ops PRIVATE ${TORCH_INCLUDE_DIRS})
-target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} c10 cuda cublas cublasLt)
-
-install(TARGETS common_ops LIBRARY DESTINATION "sgl_kernel")
-
 # JIT Logic
 # DeepGEMM
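With this change the Hopper FA3 kernels become an optional, separately compiled flash_ops module gated by SGL_KERNEL_ENABLE_FA3 (forced on for CUDA >= 12.4 or SM90A builds), while the sparse FA2 kernels are always compiled into common_ops. A minimal sketch for checking at runtime whether an installed wheel carries the optional module; the sgl_kernel.flash_ops / sgl_kernel.common_ops module paths are assumptions derived from the install destinations above:

import importlib.util

# flash_ops is only built when SGL_KERNEL_ENABLE_FA3 ends up ON; common_ops is always installed.
have_fa3 = importlib.util.find_spec("sgl_kernel.flash_ops") is not None
have_common = importlib.util.find_spec("sgl_kernel.common_ops") is not None
print(f"common_ops: {have_common}, optional flash_ops (FA3): {have_fa3}")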
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 346b2e133e2..c2086aa5b5b 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -206,6 +206,28 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
       "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
+
+  /*
+   * From Sparse Flash Attention
+   */
+  m.def(
+      "fwd_sparse(Tensor! q, Tensor k, Tensor v, "
+      "Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
+      "Tensor!? out, Tensor? alibi_slopes, "
+      "float p_dropout, float softmax_scale, bool is_causal, "
+      "float softcap, bool return_softmax, Generator? gen)"
+      "-> Tensor[]");
+  m.impl("fwd_sparse", torch::kCUDA, &flash::mha_fwd_sparse);
+
+  m.def(
+      "varlen_fwd_sparse(Tensor! q, Tensor k, Tensor v, "
+      "Tensor block_count, Tensor block_offset, Tensor column_count, Tensor column_index, "
+      "Tensor!? out, Tensor cu_seqlens_q, "
+      "Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? alibi_slopes, "
+      "int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
+      "bool is_causal, float softcap, bool return_softmax, "
+      "Generator? gen) -> Tensor[]");
+  m.impl("varlen_fwd_sparse", torch::kCUDA, &flash::mha_varlen_fwd_sparse);
 }
 
 REGISTER_EXTENSION(common_ops)
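Once common_ops is loaded, the two ops registered above resolve under torch.ops.sgl_kernel. A quick sanity-check sketch, assuming that importing the sgl_kernel package loads the extension and runs the registration (the _schema attribute is private PyTorch API, used here only for a readable dump):

import torch
import sgl_kernel  # assumed to load common_ops and execute the TORCH_LIBRARY_FRAGMENT above

# Resolution fails loudly if the fragment was not registered; otherwise print the schemas.
print(torch.ops.sgl_kernel.fwd_sparse.default._schema)
print(torch.ops.sgl_kernel.varlen_fwd_sparse.default._schema)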
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index d1222b1dd4f..07046800df2 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -256,18 +256,21 @@ void min_p_sampling_from_probs(
     double min_p_val,
     bool deterministic,
     int64_t cuda_stream);
+
 void top_k_renorm_probs(
     at::Tensor probs,
     at::Tensor renorm_probs,
     std::optional<at::Tensor> maybe_top_k_arr,
     int64_t top_k_val,
     int64_t cuda_stream);
+
 void top_p_renorm_probs(
     at::Tensor probs,
     at::Tensor renorm_probs,
     std::optional<at::Tensor> maybe_top_p_arr,
     double top_p_val,
     int64_t cuda_stream);
+
 void top_k_top_p_sampling_from_probs(
     at::Tensor probs,
     at::Tensor uniform_samples,
@@ -279,6 +282,7 @@ void top_k_top_p_sampling_from_probs(
     double top_p_val,
     bool deterministic,
     int64_t cuda_stream);
+
 void top_p_sampling_from_probs(
     at::Tensor probs,
     at::Tensor uniform_samples,
@@ -288,3 +292,49 @@ void top_p_sampling_from_probs(
     double top_p_val,
     bool deterministic,
     int64_t cuda_stream);
+
+namespace flash {
+/*
+ * From fa2 sparse
+ */
+std::vector<at::Tensor> mha_fwd_sparse(
+    at::Tensor& q,        // batch_size x seqlen_q x num_heads x head_size
+    const at::Tensor& k,  // batch_size x seqlen_k x num_heads_k x head_size
+    const at::Tensor& v,  // batch_size x seqlen_k x num_heads_k x head_size
+    const at::Tensor& block_count,
+    const at::Tensor& block_offset,
+    const at::Tensor& column_count,
+    const at::Tensor& column_index,
+    const std::optional<at::Tensor>& out_,           // batch_size x seqlen_q x num_heads x head_size
+    const std::optional<at::Tensor>& alibi_slopes_,  // num_heads or batch_size x num_heads
+    const double p_dropout,
+    const double softmax_scale,
+    bool is_causal,
+    const double softcap,
+    const bool return_softmax,
+    std::optional<at::Generator> gen_);
+
+std::vector<at::Tensor> mha_varlen_fwd_sparse(
+    at::Tensor& q,        // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+    const at::Tensor& k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
+    const at::Tensor& v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i.
+    const at::Tensor& block_count,
+    const at::Tensor& block_offset,
+    const at::Tensor& column_count,
+    const at::Tensor& column_index,
+    const c10::optional<at::Tensor>& out_,  // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+    const at::Tensor& cu_seqlens_q,         // b+1
+    const at::Tensor& cu_seqlens_k,         // b+1
+    const c10::optional<at::Tensor>&
+        seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
+    const c10::optional<at::Tensor>& alibi_slopes_,  // num_heads or b x num_heads
+    int64_t max_seqlen_q,
+    const int64_t max_seqlen_k,
+    const double p_dropout,
+    const double softmax_scale,
+    const bool zero_tensors,
+    bool is_causal,
+    const double softcap,
+    const bool return_softmax,
+    c10::optional<at::Generator> gen_);
+}  // namespace flash
diff --git a/sgl-kernel/python/sgl_kernel/sparse_flash_attn.py b/sgl-kernel/python/sgl_kernel/sparse_flash_attn.py
new file mode 100644
index 00000000000..c4ffad7daa9
--- /dev/null
+++ b/sgl-kernel/python/sgl_kernel/sparse_flash_attn.py
@@ -0,0 +1,175 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+
+def maybe_contiguous(x):
+    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
+
+
+def sparse_attn_func(
+    q,
+    k,
+    v,
+    block_count,
+    block_offset,
+    column_count,
+    column_index,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    softcap=0.0,  # 0.0 means deactivated
+    alibi_slopes=None,
+    deterministic=False,
+    return_attn_probs=False,
+    *,
+    return_softmax_lse=False,
+    out=None,
+):
+    """Compute attention with vertical and slash sparsity patterns.
+    Most arguments are the same as in the flash_attn_func interface, except for 4 extra args:
+    block_count and block_offset for slash sparsity patterns, and
+    column_count and column_index for vertical sparsity patterns.
+    For more details please refer to Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.
+
+    Arguments:
+        q: (batch_size, seqlen, nheads, headdim)
+        k: (batch_size, seqlen, nheads_k, headdim)
+        v: (batch_size, seqlen, nheads_k, headdim)
+        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
+        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
+        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
+        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+ """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + out, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + softcap, + return_attn_probs and dropout_p > 0, + None, + ) + return (out, softmax_lse) if return_softmax_lse else out + + +def sparse_attn_varlen_func( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + *, + return_softmax_lse=False, + out=None, +): + """Compute attention with vertical and slash sparsity patterns. + Most Arguments are the same with the flash_attn_varlen_func interface, except for 4 extra args: + block_count and block_offset for slash sparsity patterns, and + column_count and column_index for vertical sparsity patterns. + For more details please refer to Appendix C.4.2 of paper https://arxiv.org/abs/2407.02490. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M)) + block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S) + column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M)) + column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V) + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). 
+ """ + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default( + q, + k, + v, + block_count, + block_offset, + column_count, + column_index, + out, + cu_seqlens_q, + cu_seqlens_k, + None, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + softcap, + return_attn_probs and dropout_p > 0, + None, + ) + return (out, softmax_lse) if return_softmax_lse else out diff --git a/sgl-kernel/tests/test_sparse_flash_attn.py b/sgl-kernel/tests/test_sparse_flash_attn.py new file mode 100644 index 00000000000..bb964f33532 --- /dev/null +++ b/sgl-kernel/tests/test_sparse_flash_attn.py @@ -0,0 +1,348 @@ +import math +from typing import List, Optional, Tuple + +import pytest +import torch +from einops import rearrange, repeat +from sgl_kernel.sparse_flash_attn import sparse_attn_func, sparse_attn_varlen_func + + +def ref_attn( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, + upcast=True, + reorder_ops=False, + key_leftpad=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + window_size: (int, int), left and right window size + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. 
diff --git a/sgl-kernel/tests/test_sparse_flash_attn.py b/sgl-kernel/tests/test_sparse_flash_attn.py
new file mode 100644
index 00000000000..bb964f33532
--- /dev/null
+++ b/sgl-kernel/tests/test_sparse_flash_attn.py
@@ -0,0 +1,348 @@
+import math
+from typing import List, Optional, Tuple
+
+import pytest
+import torch
+from einops import rearrange, repeat
+from sgl_kernel.sparse_flash_attn import sparse_attn_func, sparse_attn_varlen_func
+
+
+def ref_attn(
+    q,
+    k,
+    v,
+    query_padding_mask=None,
+    key_padding_mask=None,
+    attn_bias=None,
+    dropout_p=0.0,
+    dropout_mask=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite window size
+    softcap=0.0,
+    upcast=True,
+    reorder_ops=False,
+    key_leftpad=None,
+):
+    """
+    Arguments:
+        q: (batch_size, seqlen_q, nheads, head_dim)
+        k: (batch_size, seqlen_k, nheads_k, head_dim)
+        v: (batch_size, seqlen_k, nheads_k, head_dim)
+        query_padding_mask: (batch_size, seqlen_q)
+        key_padding_mask: (batch_size, seqlen_k)
+        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
+        dropout_p: float
+        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
+        causal: whether to apply causal masking
+        window_size: (int, int), left and right window size
+        upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast
+            output back to fp16/bf16.
+        reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
+            without changing the math. This is to estimate the numerical error from operation
+            reordering.
+    Output:
+        output: (batch_size, seqlen_q, nheads, head_dim)
+        lse: (batch_size, nheads, seqlen_q)
+    """
+    if causal:
+        window_size = (window_size[0], 0)
+    dtype_og = q.dtype
+    if upcast:
+        q, k, v = q.float(), k.float(), v.float()
+    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
+    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    d = q.shape[-1]
+    if not reorder_ops:
+        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
+    else:
+        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
+
+    lse_ref = scores.logsumexp(dim=-1)
+
+    if softcap > 0:
+        scores = scores / softcap
+        scores = scores.tanh()
+        scores = scores * softcap
+    if key_padding_mask is not None:
+        scores.masked_fill_(
+            rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
+        )
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        # NOTE: construct_local_mask is not defined in this file; windowed attention
+        # is never exercised by the tests below.
+        local_mask = construct_local_mask(
+            seqlen_q,
+            seqlen_k,
+            window_size,
+            query_padding_mask,
+            key_padding_mask,
+            q.device,
+            key_leftpad=key_leftpad,
+        )
+        scores.masked_fill_(local_mask, float("-inf"))
+    if attn_bias is not None:
+        scores = scores + attn_bias
+    attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    # Some rows might be completely masked out so we fill them with zero instead of NaN
+    if window_size[0] >= 0 or window_size[1] >= 0:
+        attention = attention.masked_fill(
+            torch.all(local_mask, dim=-1, keepdim=True), 0.0
+        )
+    # We want to mask here so that the attention matrix doesn't have any NaNs
+    # Otherwise we'll get NaN in dV
+    if query_padding_mask is not None:
+        attention = attention.masked_fill(
+            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
+        )
+    dropout_scaling = 1.0 / (1 - dropout_p)
+    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
+    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
+    if dropout_mask is not None:
+        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
+    else:
+        attention_drop = attention
+    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
+    if query_padding_mask is not None:
+        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+
+    return output.to(dtype=dtype_og), lse_ref
+
+
+def ref_paged_attn(
+    query: torch.Tensor,
+    key_cache: torch.Tensor,
+    value_cache: torch.Tensor,
+    query_lens: List[int],
+    kv_lens: List[int],
+    block_tables: torch.Tensor,
+    scale: float,
+    sliding_window: Optional[int] = None,
+    soft_cap: Optional[float] = None,
+) -> torch.Tensor:
+    num_seqs = len(query_lens)
+    block_tables = block_tables.cpu().numpy()
+    _, block_size, num_kv_heads, head_size = key_cache.shape
+
+    outputs: List[torch.Tensor] = []
+    start_idx = 0
+    for i in range(num_seqs):
+        query_len = query_lens[i]
+        kv_len = kv_lens[i]
+        q = query[start_idx : start_idx + query_len]
+        q *= scale
+
+        num_kv_blocks = (kv_len + block_size - 1) // block_size
+        block_indices = block_tables[i, :num_kv_blocks]
+
+        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
+        k = k[:kv_len]
+        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
+        v = v[:kv_len]
+
+        if q.shape[1] != k.shape[1]:
+            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
+            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)
+        attn = torch.einsum("qhd,khd->hqk", q, k).float()
+        empty_mask = torch.ones(query_len, kv_len)
+        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
+        if sliding_window is not None:
+            sliding_window_mask = (
+                torch.triu(
+                    empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
+                )
+                .bool()
+                .logical_not()
+            )
+            mask |= sliding_window_mask
+        if soft_cap is not None:
+            attn = soft_cap * torch.tanh(attn / soft_cap)
+        attn.masked_fill_(mask, float("-inf"))
+        attn = torch.softmax(attn, dim=-1).to(v.dtype)
+        out = torch.einsum("hqk,khd->qhd", attn, v)
+
+        outputs.append(out)
+        start_idx += query_len
+
+    return torch.cat(outputs, dim=0)
+
+
+@pytest.mark.parametrize("batch_size", [1, 2])
+@pytest.mark.parametrize(
+    "seq_lens",
+    [
+        (1, 1),
+        (1, 1024),
+        (1, 2048),
+        (1023, 2049),
+        (1023, 1023),
+        (32, 32),
+        (65, 65),
+        (129, 129),
+    ],
+)
+@pytest.mark.parametrize("num_heads", [1, 2, 4])
+@pytest.mark.parametrize("head_size", [128])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("NNZ_S", [0, 1, 2, 3, 7, 15, 32])
+@torch.inference_mode()
+def test_sparse_attention(
+    batch_size,
+    seq_lens,
+    num_heads,
+    head_size,
+    dtype,
+    NNZ_S,
+) -> None:
+    torch.set_default_device("cuda")
+    torch.cuda.manual_seed_all(0)
+    block_size_M = 64
+    block_size_N = 64
+    seqlen_q, seqlen_k = seq_lens
+    q = torch.randn(
+        batch_size, seqlen_q, num_heads, head_size, dtype=dtype, requires_grad=False
+    )
+    k = torch.randn(
+        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
+    )
+    v = torch.randn(
+        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
+    )
+    NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M
+    if NNZ_S * block_size_N > seqlen_k:
+        return
+    NNZ_V = seqlen_k - NNZ_S * block_size_N
+    block_count = torch.tensor(
+        [NNZ_S] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
+    ).reshape(batch_size, num_heads, NUM_ROWS)
+    column_count = torch.tensor(
+        [NNZ_V] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
+    ).reshape(batch_size, num_heads, NUM_ROWS)
+    block_offset = torch.tensor(
+        [[i * block_size_N for i in range(NNZ_S)]] * batch_size * NUM_ROWS * num_heads,
+        dtype=torch.int32,
+    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
+    column_index = torch.tensor(
+        [[NNZ_S * block_size_N + i for i in range(NNZ_V)]]
+        * batch_size
+        * NUM_ROWS
+        * num_heads,
+        dtype=torch.int32,
+    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
+    out, lse = sparse_attn_func(
+        q,
+        k,
+        v,
+        block_count,
+        block_offset,
+        column_count,
+        column_index,
+        return_softmax_lse=True,
+    )
+
+    ref_out, ref_lse = ref_attn(q, k, v)
+
+    torch.testing.assert_close(
+        out, ref_out, atol=2e-2, rtol=1e-2, msg=f"{torch.max(torch.abs(out - ref_out))}"
+    )
+    torch.testing.assert_close(
+        lse, ref_lse, atol=2e-2, rtol=1e-2, msg=f"{torch.max(torch.abs(lse - ref_lse))}"
+    )
+
+
+# @pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
+#                                       [(1024, 1328), (1, 2048)],
+#                                       [(1025, 1328), (2, 2048)],
+#                                       [(1025, 2049), (2, 1281)],
+#                                       ])
+# @pytest.mark.parametrize("head_size", [128])
+# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+# @torch.inference_mode()
+# def test_sparse_attention_varlen(
+#     seq_lens,
+#     head_size,
+#     dtype,
+# ) -> None:
+#     torch.set_default_device("cuda")
+#     torch.cuda.manual_seed_all(0)
+#     block_size_M = 64
+#     block_size_N = 64
+#     num_seqs = len(seq_lens)
+#     query_lens = [x[0] for x in seq_lens]
+#     kv_lens = [x[1] for x in seq_lens]
+#     num_heads = 1
+#     query = torch.randn(sum(query_lens),
+#                         num_heads,
+#                         head_size,
+#                         dtype=dtype)
+#     key = torch.randn(sum(kv_lens),
+#                       num_heads,
+#                       head_size,
+#                       dtype=dtype)
+#     value = torch.randn_like(key)
+#     cu_query_lens = torch.tensor([0] + query_lens,
+#                                  dtype=torch.int32).cumsum(dim=0,
+#                                                            dtype=torch.int32)
+#     cu_kv_lens = torch.tensor([0] + kv_lens,
+#                               dtype=torch.int32).cumsum(dim=0,
+#                                                         dtype=torch.int32)
+#     max_query_len = max(query_lens)
+#     max_kv_len = max(kv_lens)
+
+#     NUM_ROWS = (max_query_len + block_size_M - 1) // block_size_M
+#     NNZ_S = 20
+#     NNZ_V = 2048
+#     batch_size = len(query_lens)
+
+#     block_counts = []
+#     column_counts = []
+#     block_offsets = []
+#     column_indices = []
+#     for b in range(batch_size):
+#         block_counts.append(torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
+#         columns = kv_lens[b] - NNZ_S * block_size_N
+#         column_counts.append(torch.tensor([columns] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
+#         block_offsets.append(torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_S))
+#         column_indices.append(torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_V))
+#     block_count = torch.concat(block_counts).reshape(batch_size, num_heads, NUM_ROWS)
+#     column_count = torch.concat(column_counts).reshape(batch_size, num_heads, NUM_ROWS)
+#     block_offset = torch.concat(block_offsets).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
+#     column_index = torch.concat(column_indices).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
+#     out, lse = sparse_attn_varlen_func(
+#         query,
+#         key,
+#         value,
+#         block_count,
+#         block_offset,
+#         column_count,
+#         column_index,
+#         cu_seqlens_q=cu_query_lens,
+#         cu_seqlens_k=cu_kv_lens,
+#         max_seqlen_q=max_query_len,
+#         max_seqlen_k=max_kv_len,
+#         return_softmax_lse=True,
+#     )
+
+#     max_num_blocks_per_seq = (max_kv_len + 2048 - 1) // 2048
+#     block_tables = torch.randint(0,
+#                                  2048,
+#                                  (len(query_lens), max_num_blocks_per_seq),
+#                                  dtype=torch.int32)
+#     scale = head_size**-0.5
+
+#     ref_out, ref_lse, _ = ref_paged_attn(
+#         query,
+#         key,
+#         value,
+#         query_lens=query_lens,
+#         kv_lens=kv_lens,
+#         block_tables=block_tables,
+#         scale=scale
+#     )
+
+#     torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
+#         f"{torch.max(torch.abs(out - ref_out))}"
+#     torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
+#         f"{torch.max(torch.abs(lse - ref_lse))}"
+
+if __name__ == "__main__":
+    pytest.main([__file__])