diff --git a/flashinfer/gdn_kernels/gdn_decode_bf16_state.py b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py index b82e554259..5388e1b870 100644 --- a/flashinfer/gdn_kernels/gdn_decode_bf16_state.py +++ b/flashinfer/gdn_kernels/gdn_decode_bf16_state.py @@ -751,6 +751,8 @@ def gated_delta_rule_decode_kernel_seqlen1( value_head_idx = bidx % HV query_head_idx = value_head_idx // (HV // H) pool_batch_idx = gH_slot_indices[batch_idx] + if pool_batch_idx < 0: + pool_batch_idx = cutlass.Int32(0) smem = utils.SmemAllocator() @@ -1130,6 +1132,8 @@ def gated_delta_rule_decode_kernel_seqlen234_unified( value_head_idx = bidx % HV query_head_idx = value_head_idx // (HV // H) pool_batch_idx = gH_slot_indices[batch_idx] + if pool_batch_idx < 0: + pool_batch_idx = cutlass.Int32(0) warp_idx = tidx // 32 lane_idx = tidx % 32 @@ -1563,6 +1567,8 @@ def gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk( query_head_idx = value_head_idx // (HV // H) v_row_base = v_chunk_idx * 32 pool_batch_idx = gH_slot_indices[batch_idx] + if pool_batch_idx < 0: + pool_batch_idx = cutlass.Int32(0) smem = utils.SmemAllocator() diff --git a/tests/gdn/test_decode_delta_rule.py b/tests/gdn/test_decode_delta_rule.py index 0d764a7fc2..7661f3016f 100644 --- a/tests/gdn/test_decode_delta_rule.py +++ b/tests/gdn/test_decode_delta_rule.py @@ -23,7 +23,6 @@ import torch import pytest -pytestmark = pytest.mark.skip(reason="Temporarily skipped due to CI failures.") try: from .reference_delta_rule import decode_delta_rule, verify_delta_rule @@ -203,6 +202,7 @@ def _test_decode_kernel_pretranspose( "num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)], ) +@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_decode_kernel_basic_pretranspose( @@ -368,6 +368,7 @@ def _test_decode_kernel_nontranspose( "num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)], ) 
+@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_decode_kernel_basic_nontranspose( @@ -512,6 +513,7 @@ def _test_decode_kernel_pretranspose_pool( @pytest.mark.parametrize("scale", [1.0]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) +@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_decode_kernel_pretranspose_pool( @@ -768,6 +770,7 @@ def _test_decode_kernel_pretranspose_pool_all_padding( @pytest.mark.parametrize("scale", [1.0]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) +@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 4, 8, 32, 127]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_decode_kernel_pretranspose_pool_negative_indices( @@ -795,6 +798,7 @@ def test_decode_kernel_pretranspose_pool_negative_indices( @pytest.mark.parametrize("scale", [1.0]) @pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) +@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_decode_kernel_pretranspose_pool_all_padding( @@ -819,6 +823,151 @@ def test_decode_kernel_pretranspose_pool_all_padding( ) +# ============================================================================ +# Test bf16 decode kernel with negative (padding) indices +# +# Verifies that the bf16 fast-path kernel handles negative indices correctly +# via the slot-0 null buffer pattern: negative indices are 
# redirected to slot 0 inside the kernel. Valid slots must produce correct
# output and updated state; the kernel must not crash.
# ============================================================================


def _test_decode_kernel_bf16_padding_indices(
    batch_size: int,
    num_q_heads: int,
    num_v_heads: int,
    head_size: int,
    scale: float,
    padding_fraction: float = 0.5,
    seed: int = 0,
):
    """Run the pooled decode path with a mix of padding (-1) and real slot ids.

    The pool is built with an all-zero slot 0 and random contents in every
    other slot.  Batch entries either point at slots ``1..batch_size`` or are
    marked as padding with ``-1``.  The checks are:

    * rows with a real slot id match a direct-state call on the same inputs
      (output and updated state, ``atol == rtol == 5e-3``);
    * every slot other than slot 0 that no batch entry points at is
      bit-identical to its initial contents;
    * if any entry is padding, slot 0 no longer equals its initial contents.

    Args:
        batch_size: Number of decode requests in the batch.
        num_q_heads: Head count used for both ``q`` and ``k``.
        num_v_heads: Head count used for ``v``, the gates and the state pool.
        head_size: Per-head dimension; the state is ``head_size x head_size``.
        scale: Forwarded unchanged to the kernel.
        padding_fraction: Threshold a uniform random draw is compared against
            to decide whether a batch entry becomes padding.
        seed: Seed applied to the Python, torch CPU and torch CUDA generators.
    """
    _skip_if_not_sm90_or_later()

    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    bf16 = torch.bfloat16
    cuda_dev = torch.device("cuda")
    # Slot 0 is reserved as the null buffer, so real slots start at index 1.
    pool_size = batch_size * 2 + 1

    qk_shape = (batch_size, 1, num_q_heads, head_size)
    gate_shape = (batch_size, 1, num_v_heads)
    state_shape = (num_v_heads, head_size, head_size)

    # NOTE: the order of the random draws below is kept fixed on purpose so
    # that a given seed always yields the same test data.
    with cuda_dev:
        q = torch.randn(*qk_shape, dtype=bf16)
        k = torch.nn.functional.normalize(
            torch.randn(*qk_shape, dtype=bf16),
            p=2.0,
            dim=-1,
        )
        v = torch.randn(batch_size, 1, num_v_heads, head_size, dtype=bf16)

        A_log = torch.randn(num_v_heads, dtype=torch.float32) * 0.1
        dt_bias = torch.randn(num_v_heads, dtype=torch.float32) * 0.1
        a = torch.randn(*gate_shape, dtype=bf16) * 0.1
        b = torch.randn(*gate_shape, dtype=bf16)

        # Zero-filled null buffer in slot 0, random state everywhere else.
        pool = torch.zeros(pool_size, *state_shape, dtype=bf16)
        pool[1:] = torch.randn(pool_size - 1, *state_shape, dtype=bf16)

    # Entry i initially targets slot i + 1; a random subset is then turned
    # into padding by overwriting its slot id with -1.
    indices = torch.arange(1, batch_size + 1, dtype=torch.int32, device=cuda_dev)
    mask = torch.rand(batch_size, device=cuda_dev) < padding_fraction
    if batch_size >= 2:
        # Guarantee both kinds of entry are present in the batch.
        mask[0] = False
        mask[-1] = True
    indices[mask] = -1

    valid_mask = indices >= 0
    live_slots = indices[valid_mask]

    # Arguments that are identical for the pooled call and the reference call.
    shared_kwargs = dict(
        A_log=A_log,
        dt_bias=dt_bias,
        scale=scale,
        use_qk_l2norm=True,
    )

    # ── Pool path under test ─────────────────────────────────────────────────
    pool_under_test = pool.clone()
    out_pool, _ = gated_delta_rule_decode_pretranspose(
        q=q,
        k=k,
        v=v,
        state=None,
        a=a,
        b=b,
        initial_state=pool_under_test,
        initial_state_indices=indices,
        **shared_kwargs,
    )
    torch.cuda.synchronize()

    # ── Direct-state reference, restricted to the non-padding rows ───────────
    if live_slots.numel() > 0:
        gathered = pool[live_slots].clone()
        out_direct, updated = gated_delta_rule_decode_pretranspose(
            q=q[valid_mask],
            k=k[valid_mask],
            v=v[valid_mask],
            state=gathered,
            a=a[valid_mask],
            b=b[valid_mask],
            **shared_kwargs,
        )
        tol = dict(atol=5e-3, rtol=5e-3)
        torch.testing.assert_close(out_pool[valid_mask], out_direct, **tol)
        torch.testing.assert_close(pool_under_test[live_slots], updated, **tol)

    # Slots that no batch entry points at must come back bit-identical.
    # Slot 0 is excluded: it is the null buffer and may legitimately change.
    untouched = torch.ones(pool_size, dtype=torch.bool, device=cuda_dev)
    untouched[live_slots.to(cuda_dev)] = False
    untouched[0] = False
    torch.testing.assert_close(
        pool_under_test[untouched], pool[untouched], atol=0.0, rtol=0.0
    )

    # Padding entries are expected to land in slot 0.  If the kernel did not
    # redirect them, slot 0 would still hold its initial zeros and this check
    # would fail.
    if mask.any():
        slot0_changed = not torch.equal(pool_under_test[0], pool[0])
        assert slot0_changed, (
            "Slot 0 (null buffer) should have been updated by padding slots; "
            "if it is unchanged the kernel fix is missing"
        )

    n_valid = valid_mask.sum().item()
    n_padding = mask.sum().item()
    print(
        f"✓ bf16 padding indices test passed "
        f"(batch={batch_size}, valid={n_valid}, padding={n_padding})"
    )


@pytest.mark.parametrize("scale", [1.0])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("num_q_heads, num_v_heads", [(16, 32)])
@pytest.mark.parametrize("batch_size", [1, 4, 16, 32])
def test_decode_kernel_bf16_padding_indices(
    batch_size: int,
    num_q_heads: int,
    num_v_heads: int,
    head_size: int,
    scale: float,
    seed: int = int(os.environ.get("SEED", "0")),
):
    """Parametrized entry point for the bf16 padding-index check.

    The seed default is read from the ``SEED`` environment variable once, at
    import time, and falls back to 0.
    """
    _test_decode_kernel_bf16_padding_indices(
        batch_size=batch_size,
        num_q_heads=num_q_heads,
        num_v_heads=num_v_heads,
        head_size=head_size,
        scale=scale,
        seed=seed,
    )


# ============================================================================
# Test verify kernel with MTP version (Multiple Token Processing)
# Reference: fp32 h state (default).
@@ -1016,6 +1165,7 @@ def _test_verify_kernel_mtp( "num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)], ) +@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_verify_kernel_mtp( @@ -1056,6 +1206,7 @@ def test_verify_kernel_mtp( @pytest.mark.parametrize("seq_len", [2, 3, 4, 5, 6, 7, 8]) +@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_mtp_fp32_state_with_cache_and_state_update( @@ -1268,6 +1419,7 @@ def _test_gdn_decode_klast_bf16_state_kernel( "num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)], ) +@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("dtype", ["bfloat16"]) def test_gdn_decode_klast_bf16_state_kernel( @@ -1299,6 +1451,7 @@ def test_gdn_decode_klast_bf16_state_kernel( ) +@pytest.mark.skip(reason="Temporarily skipped due to CI failures.") @pytest.mark.parametrize("seq_len", [1, 2, 3, 4]) @pytest.mark.parametrize("batch_size", [1, 2, 4]) @pytest.mark.parametrize("head_size", [128])