diff --git a/flashinfer/decode.py b/flashinfer/decode.py index 574f8a024c..9189d09767 100644 --- a/flashinfer/decode.py +++ b/flashinfer/decode.py @@ -51,6 +51,7 @@ _check_pos_encoding_mode, check_shape_dtype_device, _get_cache_alibi_slopes_buf, + _get_sink_buf, _get_cache_buf, _get_range_buf, _unpack_paged_kv_cache, @@ -242,6 +243,7 @@ def run_batch_decode( window_left: int, enable_pdl: bool, alibi_slopes: Optional[torch.Tensor], + maybe_s_aux: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, @@ -263,6 +265,7 @@ def run_batch_decode( window_left, enable_pdl, alibi_slopes, + maybe_s_aux, logits_soft_cap, sm_scale, 1.0 / rope_scale, # rope_rcp_scale @@ -286,6 +289,7 @@ def _fake_run_batch_decode( window_left: int, enable_pdl: bool, alibi_slopes: Optional[torch.Tensor], + maybe_s_aux: Optional[torch.Tensor], logits_soft_cap: float, sm_scale: float, rope_scale: float, @@ -384,6 +388,7 @@ def single_decode_with_kv_cache( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, return_lse: Literal[True] = True, + sinks: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: ... @@ -403,6 +408,7 @@ def single_decode_with_kv_cache( rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, return_lse: bool = False, + sinks: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: r"""Decode attention with KV Cache for single request, return attention output. 
@@ -529,6 +535,7 @@ def single_decode_with_kv_cache( window_left, None, # packed_custom_mask _get_cache_alibi_slopes_buf(num_qo_heads, q.device), + sinks, # maybe_s_aux logits_soft_cap, sm_scale, None, # scale_q, not supported yet @@ -1330,7 +1337,7 @@ def run( self._kv_lens_buffer, page_size, self._max_kv_len, - sinks, + _get_sink_buf(sinks), ] self._cached_module.paged_run(*run_args) @@ -1364,6 +1371,7 @@ def run( else: run_args += [ _get_cache_alibi_slopes_buf(q.shape[1], q.device), + _get_sink_buf(sinks), logits_soft_cap, sm_scale, rope_scale, diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py index 475acdcd1e..32e9c8aaa3 100644 --- a/flashinfer/jit/attention/modules.py +++ b/flashinfer/jit/attention/modules.py @@ -467,8 +467,8 @@ def gen_single_decode_module( dtype_o, head_dim_qk, head_dim_vo, - ["maybe_alibi_slopes"], # additional_tensor_names - ["float"], # additional_tensor_dtypes + ["maybe_alibi_slopes", "maybe_s_aux"], # additional_tensor_names + ["float", "float"], # additional_tensor_dtypes [ "logits_soft_cap", "sm_scale", @@ -516,8 +516,12 @@ def gen_single_prefill_module( if backend == "fa2": assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend" - additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"] - additional_tensor_dtypes = ["uint8_t", "float"] + additional_tensor_names = [ + "maybe_custom_mask", + "maybe_alibi_slopes", + "maybe_s_aux", + ] + additional_tensor_dtypes = ["uint8_t", "float", "float"] additional_scalar_names = [ "logits_soft_cap", "sm_scale", @@ -760,8 +764,8 @@ def gen_batch_decode_module( dtype_idx, head_dim_qk, head_dim_vo, - ["maybe_alibi_slopes"], # additional_tensor_names - ["float"], # additional_tensor_dtypes + ["maybe_alibi_slopes", "maybe_s_aux"], # additional_tensor_names + ["float", "float"], # additional_tensor_dtypes [ "logits_soft_cap", "sm_scale", diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 49abe60897..a2a9a7e3aa 100755 --- 
def _get_sink_buf(
    sinks: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Normalize an optional sink tensor before it is handed to a kernel.

    Args:
        sinks: Sink values (one per query head), or ``None`` when sink
            attention is not used.

    Returns:
        ``None`` when ``sinks`` is ``None``; otherwise ``sinks`` converted to
        ``torch.float32`` and laid out contiguously. A tensor that already
        satisfies both properties is returned unchanged, not copied.
    """
    if sinks is not None:
        # The kernels read the buffer through a raw ``float*``, so both the
        # dtype and the memory layout have to be fixed up here.
        as_fp32 = sinks.to(torch.float32)
        return as_fp32.contiguous()
    return None
(size_t i = 0; i < vec_size; ++i) { st.o[i] = variant.OutputTransform(params, st.o[i], bx, /*qo_idx=*/0, qo_head_idx, st.m, st.d, diff --git a/include/flashinfer/attention/default_decode_params.cuh b/include/flashinfer/attention/default_decode_params.cuh index d06e46338c..df5c8321f5 100644 --- a/include/flashinfer/attention/default_decode_params.cuh +++ b/include/flashinfer/attention/default_decode_params.cuh @@ -37,6 +37,7 @@ struct SingleDecodeParams { DTypeO* o; float* lse; float* maybe_alibi_slopes; + float* maybe_s_aux; uint32_t kv_len; uint32_t num_qo_heads; uint32_t num_kv_heads; @@ -58,6 +59,7 @@ struct SingleDecodeParams { o(nullptr), lse(nullptr), maybe_alibi_slopes(nullptr), + maybe_s_aux(nullptr), kv_len(0), num_qo_heads(0), num_kv_heads(0), @@ -84,6 +86,7 @@ struct SingleDecodeParams { o(o), lse(nullptr), maybe_alibi_slopes(maybe_alibi_slopes), + maybe_s_aux(nullptr), kv_len(seq_len), num_qo_heads(num_qo_heads), num_kv_heads(num_kv_heads), @@ -118,6 +121,7 @@ struct BatchDecodeParams { DTypeO* o; float* lse; float* maybe_alibi_slopes; + float* maybe_s_aux; uint32_t padded_batch_size; uint32_t num_qo_heads; IdType q_stride_n; @@ -142,6 +146,7 @@ struct BatchDecodeParams { o(nullptr), lse(nullptr), maybe_alibi_slopes(nullptr), + maybe_s_aux(nullptr), padded_batch_size(0), num_qo_heads(0), q_stride_n(0), @@ -170,6 +175,7 @@ struct BatchDecodeParams { o(o), lse(lse), maybe_alibi_slopes(maybe_alibi_slopes), + maybe_s_aux(nullptr), padded_batch_size(0), num_qo_heads(num_qo_heads), q_stride_n(q_stride_n), diff --git a/include/flashinfer/attention/default_prefill_params.cuh b/include/flashinfer/attention/default_prefill_params.cuh index 2e857fcc72..cb270cedcb 100644 --- a/include/flashinfer/attention/default_prefill_params.cuh +++ b/include/flashinfer/attention/default_prefill_params.cuh @@ -38,6 +38,7 @@ struct SinglePrefillParams { DTypeO* o; float* lse; float* maybe_alibi_slopes; + float* maybe_s_aux; uint_fastdiv group_size; uint32_t qo_len; uint32_t 
kv_len; @@ -66,6 +67,7 @@ struct SinglePrefillParams { o(nullptr), lse(nullptr), maybe_alibi_slopes(nullptr), + maybe_s_aux(nullptr), group_size(), qo_len(0), kv_len(0), @@ -86,7 +88,7 @@ struct SinglePrefillParams { partition_kv(false) {} __host__ SinglePrefillParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask, - DTypeO* o, float* lse, float* maybe_alibi_slopes, + DTypeO* o, float* lse, float* maybe_alibi_slopes, float* maybe_s_aux, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, uint32_t kv_stride_h, uint32_t head_dim, @@ -99,6 +101,7 @@ struct SinglePrefillParams { o(o), lse(lse), maybe_alibi_slopes(maybe_alibi_slopes), + maybe_s_aux(maybe_s_aux), group_size(num_qo_heads / num_kv_heads), num_qo_heads(num_qo_heads), num_kv_heads(num_kv_heads), @@ -146,6 +149,7 @@ struct BatchPrefillRaggedParams { DTypeO* o; float* lse; float* maybe_alibi_slopes; + float* maybe_s_aux; uint_fastdiv group_size; uint32_t num_qo_heads; uint32_t num_kv_heads; @@ -190,6 +194,7 @@ struct BatchPrefillRaggedParams { o(nullptr), lse(nullptr), maybe_alibi_slopes(nullptr), + maybe_s_aux(nullptr), group_size(), num_qo_heads(0), num_kv_heads(0), @@ -224,9 +229,9 @@ struct BatchPrefillRaggedParams { IdType* q_indptr, IdType* kv_indptr, IdType* maybe_mask_indptr, IdType* maybe_q_rope_offset, IdType* maybe_k_rope_offset, DTypeO* o, float* lse, float* maybe_alibi_slopes, - uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, - uint32_t kv_stride_h, int32_t window_left, + float* maybe_s_aux, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t q_stride_n, uint32_t q_stride_h, + uint32_t kv_stride_n, uint32_t kv_stride_h, int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta) : q(q), @@ -241,6 +246,7 @@ struct BatchPrefillRaggedParams { o(o), lse(lse), 
maybe_alibi_slopes(maybe_alibi_slopes), + maybe_s_aux(maybe_s_aux), group_size(num_qo_heads / num_kv_heads), num_qo_heads(num_qo_heads), num_kv_heads(num_kv_heads), @@ -296,6 +302,7 @@ struct BatchPrefillPagedParams { DTypeO* o; float* lse; float* maybe_alibi_slopes; + float* maybe_s_aux; uint_fastdiv group_size; uint32_t num_qo_heads; IdType q_stride_n; @@ -332,6 +339,7 @@ struct BatchPrefillPagedParams { o(nullptr), lse(nullptr), maybe_alibi_slopes(nullptr), + maybe_s_aux(nullptr), group_size(), num_qo_heads(0), q_stride_n(0), @@ -361,9 +369,9 @@ struct BatchPrefillPagedParams { uint8_t* maybe_custom_mask, IdType* q_indptr, IdType* maybe_mask_indptr, IdType* maybe_q_rope_offset, DTypeO* o, float* lse, float* maybe_alibi_slopes, - uint32_t num_qo_heads, IdType q_stride_n, IdType q_stride_h, - int32_t window_left, float logits_soft_cap, float sm_scale, - float rope_scale, float rope_theta) + float* maybe_s_aux, uint32_t num_qo_heads, IdType q_stride_n, + IdType q_stride_h, int32_t window_left, float logits_soft_cap, + float sm_scale, float rope_scale, float rope_theta) : q(q), paged_kv(paged_kv), maybe_custom_mask(maybe_custom_mask), @@ -373,6 +381,7 @@ struct BatchPrefillPagedParams { o(o), lse(lse), maybe_alibi_slopes(maybe_alibi_slopes), + maybe_s_aux(maybe_s_aux), group_size(num_qo_heads / paged_kv.num_heads), num_qo_heads(num_qo_heads), q_stride_n(q_stride_n), diff --git a/include/flashinfer/attention/variants.cuh b/include/flashinfer/attention/variants.cuh index c333a36e69..89ce87d565 100644 --- a/include/flashinfer/attention/variants.cuh +++ b/include/flashinfer/attention/variants.cuh @@ -90,6 +90,16 @@ struct DefaultAttention : AttentionVariantBase { } return mask; }) + + REGISTER_M_D_UPDATE(params, kv_tile_idx, qo_head_idx, m, d, scale, { + if constexpr (use_softmax) { + if (params.maybe_s_aux != nullptr) { + constexpr float LOG2_E = 1.4426950408889634f; // log2(e) + float s_aux_val = params.maybe_s_aux[qo_head_idx]; + d += math::ptx_exp2((s_aux_val - m) * 
def sink_attention_decode_ref(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    sink: torch.Tensor,
    window_left: int,
    sm_scale: float,
) -> torch.Tensor:
    """Return the decode-mode sink-attention reference output.

    Thin wrapper that forwards its arguments to ``sink_attention_unified``
    with causal masking enabled and ``mode="incremental"``.
    """
    decode_options = {
        "causal": True,
        "sm_scale": sm_scale,
        "mode": "incremental",
    }
    return sink_attention_unified(
        q, k_cache, v_cache, sink, window_left, **decode_options
    )
+@pytest.mark.parametrize("window_left", [-1]) # Only test without sliding window +@pytest.mark.parametrize("page_size", [1, 16]) +@pytest.mark.parametrize("kv_layout", ["NHD"]) +def test_batch_decode_with_sink_attention( + batch_size, + kv_len, + num_qo_heads, + num_kv_heads, + head_dim, + window_left, + page_size, + kv_layout, +): + """Test batch decode with sink attention support.""" + torch.manual_seed(42) + device = torch.device("cuda:0") + dtype = torch.bfloat16 + + sm_scale = 1.0 / math.sqrt(head_dim) + + # Create query tensor: [batch_size, num_qo_heads, head_dim] + q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=dtype, device=device) + + # Create KV cache in paged format + num_pages_per_seq = (kv_len + page_size - 1) // page_size + total_num_pages = num_pages_per_seq * batch_size + + if kv_layout == "NHD": + kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim] + else: + kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim] + + kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device=device) + kv_data = kv_data_fp32.to(dtype) + + # Create page indices and metadata + kv_indptr = ( + torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) + * num_pages_per_seq + ) + kv_indices = torch.arange(0, total_num_pages, device=device, dtype=torch.int32) + kv_last_page_len = torch.full( + (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device=device + ) + + # Create sink tensor: [num_qo_heads] float32 + # Sink values should be on similar scale to logits (QK^T * sm_scale) + # For typical logits, use smaller range to match expected scale + sinks = torch.randn(num_qo_heads, device=device, dtype=torch.float32) * 0.5 + + # Create workspace buffer + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device=device) + + # Test with FlashInfer + wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, kv_layout + ) + wrapper.plan( + kv_indptr, + kv_indices, + 
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("kv_len", [128])
@pytest.mark.parametrize("num_qo_heads", [32])
@pytest.mark.parametrize("num_kv_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_decode_without_sink_attention(
    batch_size, kv_len, num_qo_heads, num_kv_heads, head_dim
):
    """Test that decode with ``sinks=None`` matches decode with neutral sinks.

    A sink is an extra logit that only feeds the softmax denominator, so the
    value that contributes nothing is ``-inf`` (``exp(-inf) == 0``). A sink of
    ``0.0`` is NOT neutral: it still adds ``exp(0 - m)`` to the denominator and
    scales every output down. Comparing against zero sinks would therefore
    only pass when the tolerance happens to hide that difference.
    """
    torch.manual_seed(42)
    device = torch.device("cuda:0")
    dtype = torch.bfloat16

    sm_scale = 1.0 / math.sqrt(head_dim)
    page_size = 16
    kv_layout = "NHD"

    # Query tensor: [batch_size, num_qo_heads, head_dim]
    q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=dtype, device=device)

    # Paged KV cache in NHD layout; every request owns the same number of pages.
    num_pages_per_seq = (kv_len + page_size - 1) // page_size
    total_num_pages = num_pages_per_seq * batch_size
    kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
    kv_data = torch.randn(*kv_shape, dtype=dtype, device=device)

    kv_indptr = (
        torch.arange(0, batch_size + 1, device=device, dtype=torch.int32)
        * num_pages_per_seq
    )
    kv_indices = torch.arange(0, total_num_pages, device=device, dtype=torch.int32)
    kv_last_page_len = torch.full(
        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device=device
    )

    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device=device)
    wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        pos_encoding_mode="NONE",
        data_type=dtype,
        q_data_type=dtype,
        sm_scale=sm_scale,
    )

    # Baseline: sink attention disabled.
    out_no_sinks = wrapper.run(q, kv_data, sinks=None)

    # Neutral sinks: -inf logits carry zero softmax mass, so the sink code path
    # runs but must leave the denominator (and hence the output) untouched.
    neutral_sinks = torch.full(
        (num_qo_heads,), float("-inf"), device=device, dtype=torch.float32
    )
    out_neutral_sinks = wrapper.run(q, kv_data, sinks=neutral_sinks)

    # The two runs differ only by adding 0 to the denominator; the original
    # bfloat16 tolerance is kept as a safety margin.
    torch.testing.assert_close(
        out_no_sinks, out_neutral_sinks, rtol=5e-3, atol=5e-3
    )
@pytest.mark.parametrize("kv_len", [32, 128, 512])
@pytest.mark.parametrize(
    "num_qo_heads,num_kv_heads",
    [
        (8, 8),  # MHA: equal heads
        (16, 8),  # GQA: 2:1 ratio
        (32, 8),  # GQA: 4:1 ratio
        (32, 32),  # MHA: equal heads
    ],
)
@pytest.mark.parametrize("head_dim", [64, 128])
@pytest.mark.parametrize("kv_layout", ["NHD", "HND"])
def test_single_decode_sink_attention_tensor_cores(
    kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout
):
    """Check single-request decode with sinks (tensor-core path) against the reference.

    Runs ``flashinfer.single_decode_with_kv_cache`` with
    ``use_tensor_cores=True`` and per-head sinks, sanity-checks the output,
    then compares it with ``sink_attention_decode_ref``.
    """
    torch.manual_seed(42)
    device = torch.device("cuda:0")
    dtype = torch.bfloat16
    tensor_opts = {"dtype": dtype, "device": device}

    sm_scale = 1.0 / math.sqrt(head_dim)
    window_left = -1  # No sliding window

    # K and V share one shape, chosen by the layout under test.
    is_nhd = kv_layout == "NHD"
    if is_nhd:
        kv_shape = (kv_len, num_kv_heads, head_dim)
    else:  # HND
        kv_shape = (num_kv_heads, kv_len, head_dim)

    # Draw order (q, k, v, sinks) is fixed so the seeded data is reproducible.
    q = torch.randn(num_qo_heads, head_dim, **tensor_opts)
    k = torch.randn(*kv_shape, **tensor_opts)
    v = torch.randn(*kv_shape, **tensor_opts)

    # One float32 sink per query head, kept on a scale similar to the logits.
    sinks = torch.randn(num_qo_heads, device=device, dtype=torch.float32) * 0.5

    out = flashinfer.single_decode_with_kv_cache(
        q,
        k,
        v,
        kv_layout=kv_layout,
        pos_encoding_mode="NONE",
        use_tensor_cores=True,
        sm_scale=sm_scale,
        sinks=sinks,
    )

    # Basic sanity checks before the numerical comparison.
    assert out.shape == (num_qo_heads, head_dim)
    assert out.dtype == dtype
    assert not torch.isnan(out).any()
    assert not torch.isinf(out).any()

    # The reference wants [batch, kv_len, num_kv_heads, head_dim]; bring HND
    # inputs into NHD order and prepend a batch dimension of 1.
    k_nhd = k if is_nhd else k.transpose(0, 1)
    v_nhd = v if is_nhd else v.transpose(0, 1)
    out_ref = sink_attention_decode_ref(
        q.unsqueeze(0),
        k_nhd.unsqueeze(0),
        v_nhd.unsqueeze(0),
        sinks,
        window_left,
        sm_scale,
    )
    out_ref = out_ref.squeeze(0)  # [num_qo_heads, head_dim]

    # Tolerances allow for bfloat16 precision and for a different order of
    # operations between the reference and the CUDA kernel.
    torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=3.5e-2)