From 8ba15c87e4a5d66a3b1917a9d938c79bd9f60c0d Mon Sep 17 00:00:00 2001 From: mingyangw Date: Tue, 5 May 2026 13:20:23 -0700 Subject: [PATCH 1/4] enable cutedsl d64, requires splikt_kv=1 --- .../cute_dsl/attention/wrappers/batch_mla.py | 5 ---- tests/attention/test_cute_dsl_mla_decode.py | 19 ++++++++++++--- tests/attention/test_trtllm_gen_mla.py | 24 +++++++++++++++++-- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py index b7aabc3629..e7aa3ad47f 100644 --- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py +++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py @@ -778,11 +778,6 @@ def cute_dsl_mla_decode( # Runtime validation if max_seq_len <= 0: raise ValueError(f"max_seq_len must be > 0, got {max_seq_len}") - if H < 128 and H != 1: - raise ValueError( - f"cute_dsl_mla_decode requires num_heads >= 128 (or 1 for reduction), got {H}" - ) - # Cached split_kv and workspace_size computation max_active_blocks = get_num_sm(query.device) split_kv, workspace_size = _get_split_kv_and_workspace_size( diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 3abe867595..f89d2f7d25 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -30,6 +30,18 @@ def skip_if_unsupported(): pytest.skip("CuTe DSL not available") +def skip_if_invalid_small_head_split_kv(batch_size, q_len, num_heads, device): + if num_heads >= 128: + return + + from flashinfer.cute_dsl.utils import get_num_sm + + max_active_blocks = get_num_sm(device) + split_kv = min(max(1, max_active_blocks // batch_size // (q_len * 2)), 32) + if split_kv != 1: + pytest.skip("CuTe DSL MLA with num_heads < 128 requires split_kv == 1") + + def torch_reference_mla( q_nope, q_rope, @@ -266,10 +278,11 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size, torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) -@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("batch_size", [1, 4, 128]) @pytest.mark.parametrize("seq_len_k", [128, 512]) +@pytest.mark.parametrize("num_heads", [128, 64]) def test_cute_dsl_mla_decode_via_api( - batch_size, seq_len_k, page_size=128, enable_pdl=False + batch_size, seq_len_k, num_heads, page_size=128, enable_pdl=False ): """Test MLA decode via the trtllm_batch_decode_with_kv_cache_mla API with cute-dsl backend.""" skip_if_unsupported() @@ -279,10 +292,10 @@ def test_cute_dsl_mla_decode_via_api( torch.manual_seed(42) device = torch.device("cuda") - num_heads = 128 latent_dim = 512 rope_dim = 64 q_len = 1 + skip_if_invalid_small_head_split_kv(batch_size, q_len, num_heads, device) softmax_scale = 1.0 / (latent_dim**0.5) D_qk = latent_dim + rope_dim diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index c1cf3d8a50..8da4995721 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -16,6 +16,14 @@ workspace_size = 128 * 1024 * 1024 +def get_mla_split_kv_simplified(batch_size: int, q_len: int, device) -> int: + from flashinfer.cute_dsl.utils import get_num_sm + + max_active_blocks = get_num_sm(device) + blocks_per_batch = max(1, max_active_blocks // batch_size // (q_len * 2)) + return min(blocks_per_batch, 32) + + def generate_sparse_indices( batch_size: int, q_len_per_request: int, @@ -818,8 +826,20 @@ def test_trtllm_batch_decode_mla( pytest.skip("XQA MLA does not support smaller MLA dimensions yet.") if backend == "xqa" and layer_dimensions.num_heads != 128: pytest.skip("XQA MLA only supports 128 query heads (head_group_ratio=128)") - if backend == "cute-dsl" and layer_dimensions.num_heads < 128: - pytest.skip("cute-dsl MLA requires num_heads >= 128") + if ( + backend == "cute-dsl" + and layer_dimensions.head_dimensions == smaller_mla_dimensions + ): + pytest.skip("cute-dsl MLA requires 512 latent dim and 64 rope dim") + if ( + backend == "cute-dsl" + and layer_dimensions.num_heads < 128 + and get_mla_split_kv_simplified( + batch_size, q_len_per_request, torch.device("cuda") + ) + != 1 + ): + pytest.skip("cute-dsl MLA with num_heads < 128 requires split_kv == 1") trtllm_batch_decode_mla( layer_dimensions, From a4a517ace95cddf57d1956a2e43f07632b178dd3 Mon Sep 17 00:00:00 2001 From: mingyangw Date: Tue, 5 May 2026 14:26:54 -0700 Subject: [PATCH 2/4] enable split_kv for h64 by padding to 128 CuTeDSL MLA decode uses a 128-wide physical head tile, so H=64 split-KV needs partial-output workspace pitched to that physical width before reduction. Keep logical output and reduction over H while allocating/pitching scratch with max(H, 128). Constraint: Existing SM100 CuTeDSL MLA decode config uses 128-wide QK/PV M tiles. Rejected: Leave split_kv disabled for H=64 | Kimi K2.5 needs the split-KV path. Confidence: medium Scope-risk: moderate Directive: Do not shrink split-KV workspace pitch below the physical MMA M tile without adding a true smaller-M kernel specialization. Tested: python3 -m py_compile on changed Python files; git diff --check; pre-commit hooks; remote SM100 public API smoke with H=64 and split_kv=32 before final expression cleanup. Not-tested: Full pytest sweep; remote smoke after replacing literal 128 with cutlass.max(H, 128). --- flashinfer/cute_dsl/attention/mla_config.py | 4 ++-- flashinfer/cute_dsl/attention/mla_decode.py | 16 +++++++++++----- .../cute_dsl/attention/mla_decode_fp8.py | 16 +++++++++++----- .../attention/scheduler/mla_persistent.py | 5 ++++- .../cute_dsl/attention/wrappers/batch_mla.py | 11 ----------- tests/attention/test_cute_dsl_mla_decode.py | 13 ------------- tests/attention/test_trtllm_gen_mla.py | 18 ------------------ 7 files changed, 28 insertions(+), 55 deletions(-) diff --git a/flashinfer/cute_dsl/attention/mla_config.py b/flashinfer/cute_dsl/attention/mla_config.py index ff2730d80d..405c744f02 100644 --- a/flashinfer/cute_dsl/attention/mla_config.py +++ b/flashinfer/cute_dsl/attention/mla_config.py @@ -175,7 +175,7 @@ def can_implement( return False if is_var_split_kv and not is_var_seq: return False - if H > 128 or (H < 128 and split_kv != 1): + if H > 128: return False if S < 1 or S > 4: return False @@ -220,7 +220,7 @@ def can_implement_fp8( return False if is_var_split_kv and not is_var_seq: return False - if H > 128 or (H < 128 and split_kv != 1): + if H > 128: return False if S <= 0 or S > 4: return False diff --git a/flashinfer/cute_dsl/attention/mla_decode.py b/flashinfer/cute_dsl/attention/mla_decode.py index c454a9a43f..429c26c855 100644 --- a/flashinfer/cute_dsl/attention/mla_decode.py +++ b/flashinfer/cute_dsl/attention/mla_decode.py @@ -744,22 +744,28 @@ def initialize_workspace( """Initialize workspace tensors acc_o and acc_lse for split-KV.""" acc_o, acc_lse = None, None if cutlass.const_expr(workspace is not None): + workspace_H = cutlass.max(H, cutlass.Int32(128)) align = 256 // self.q_dtype.width acc_o_layout = cute.make_layout( - (H, split_kv, D, S, B), + (workspace_H, split_kv, D, S, B), stride=( cute.assume(split_kv * D, align), cute.assume(D, align), 1, - cute.assume(split_kv * H * D, align), - cute.assume(H * split_kv * S * D, align), + cute.assume(split_kv * workspace_H * D, align), + cute.assume(workspace_H * split_kv * S * D, align), ), ) acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) acc_lse_layout = cute.make_layout( - (H, split_kv, S, B), - stride=(split_kv, 1, H * split_kv, H * split_kv * S), + (workspace_H, split_kv, S, B), + stride=( + split_kv, + 1, + workspace_H * split_kv, + workspace_H * split_kv * S, + ), ) acc_lse_iter = cute.recast_ptr( workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, diff --git a/flashinfer/cute_dsl/attention/mla_decode_fp8.py b/flashinfer/cute_dsl/attention/mla_decode_fp8.py index dcf2e9ee5c..53b2869c64 100644 --- a/flashinfer/cute_dsl/attention/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/attention/mla_decode_fp8.py @@ -762,22 +762,28 @@ def initialize_workspace( """Initialize workspace tensors acc_o and acc_lse for split-KV.""" acc_o, acc_lse = None, None if cutlass.const_expr(workspace is not None): + workspace_H = cutlass.max(H, cutlass.Int32(128)) align = 256 // self.q_dtype.width acc_o_layout = cute.make_layout( - (H, split_kv, D, S, B), + (workspace_H, split_kv, D, S, B), stride=( cute.assume(split_kv * D, align), cute.assume(D, align), 1, - cute.assume(split_kv * H * D, align), - cute.assume(H * split_kv * S * D, align), + cute.assume(split_kv * workspace_H * D, align), + cute.assume(workspace_H * split_kv * S * D, align), ), ) acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) acc_lse_layout = cute.make_layout( - (H, split_kv, S, B), - stride=(split_kv, 1, H * split_kv, H * split_kv * S), + (workspace_H, split_kv, S, B), + stride=( + split_kv, + 1, + workspace_H * split_kv, + workspace_H * split_kv * S, + ), ) acc_lse_iter = cute.recast_ptr( workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, diff --git a/flashinfer/cute_dsl/attention/scheduler/mla_persistent.py b/flashinfer/cute_dsl/attention/scheduler/mla_persistent.py index ff10dcd5d5..2e62ed4fec 100644 --- a/flashinfer/cute_dsl/attention/scheduler/mla_persistent.py +++ b/flashinfer/cute_dsl/attention/scheduler/mla_persistent.py @@ -288,4 +288,7 @@ def mla_get_workspace_size( """Get workspace size in bytes for split-KV MLA decode.""" if split_kv == 1: return 0 - return B * H * S * split_kv * (D + 1) * acc_dtype_width // 8 + # Decode packs heads into a physical 128-wide MMA-M tile. For H < 128, + # split-KV partials can still touch the padded head lanes before reduction. + workspace_heads = max(H, 128) + return B * workspace_heads * S * split_kv * (D + 1) * acc_dtype_width // 8 diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py index e7aa3ad47f..da148c2241 100644 --- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py +++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py @@ -592,11 +592,6 @@ def run( B, q_len, H, self._kv_lora_rank, max_active_blocks ) - if H < 128 and split_kv != 1: - raise ValueError( - f"num_heads={H} < 128 requires split_kv==1, got split_kv={split_kv}" - ) - # Prepare workspace is_workspace_size_zero = workspace_size == 0 if is_workspace_size_zero: @@ -784,12 +779,6 @@ def cute_dsl_mla_decode( B, q_len, H, kv_lora_rank, max_active_blocks ) - if H < 128 and split_kv != 1: - raise ValueError( - f"cute_dsl_mla_decode: num_heads={H} < 128 requires split_kv==1, " - f"got split_kv={split_kv}" - ) - # Prepare workspace assert workspace_buffer.dtype == torch.int8, ( f"workspace_buffer must be torch.int8, got {workspace_buffer.dtype}" diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index f89d2f7d25..dd40ec3ac1 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -30,18 +30,6 @@ def skip_if_unsupported(): pytest.skip("CuTe DSL not available") -def skip_if_invalid_small_head_split_kv(batch_size, q_len, num_heads, device): - if num_heads >= 128: - return - - from flashinfer.cute_dsl.utils import get_num_sm - - max_active_blocks = get_num_sm(device) - split_kv = min(max(1, max_active_blocks // batch_size // (q_len * 2)), 32) - if split_kv != 1: - pytest.skip("CuTe DSL MLA with num_heads < 128 requires split_kv == 1") - - def torch_reference_mla( q_nope, q_rope, @@ -295,7 +283,6 @@ def test_cute_dsl_mla_decode_via_api( latent_dim = 512 rope_dim = 64 q_len = 1 - skip_if_invalid_small_head_split_kv(batch_size, q_len, num_heads, device) softmax_scale = 1.0 / (latent_dim**0.5) D_qk = latent_dim + rope_dim diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index 8da4995721..bac935ea62 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -16,14 +16,6 @@ workspace_size = 128 * 1024 * 1024 -def get_mla_split_kv_simplified(batch_size: int, q_len: int, device) -> int: - from flashinfer.cute_dsl.utils import get_num_sm - - max_active_blocks = get_num_sm(device) - blocks_per_batch = max(1, max_active_blocks // batch_size // (q_len * 2)) - return min(blocks_per_batch, 32) - - def generate_sparse_indices( batch_size: int, q_len_per_request: int, @@ -831,16 +823,6 @@ def test_trtllm_batch_decode_mla( and layer_dimensions.head_dimensions == smaller_mla_dimensions ): pytest.skip("cute-dsl MLA requires 512 latent dim and 64 rope dim") - if ( - backend == "cute-dsl" - and layer_dimensions.num_heads < 128 - and get_mla_split_kv_simplified( - batch_size, q_len_per_request, torch.device("cuda") - ) - != 1 - ): - pytest.skip("cute-dsl MLA with num_heads < 128 requires split_kv == 1") - trtllm_batch_decode_mla( layer_dimensions, batch_size, From 74e44c8a8583513f89a017647c3c3472b0889331 Mon Sep 17 00:00:00 2001 From: mingyangw Date: Thu, 7 May 2026 11:21:42 -0700 Subject: [PATCH 3/4] review feedback: finite-output check, test coverage for H64, remove split_kv from can_implement Address review feedback by strengthening the H64 regression coverage and removing split_kv from the capability-check signatures now that split_kv no longer gates support eligibility. Constraint: Keep runtime split_kv handling unchanged; this only affects can_implement eligibility checks. Confidence: high Scope-risk: narrow Tested: pre-commit run --files flashinfer/cute_dsl/attention/mla_config.py flashinfer/cute_dsl/attention/mla_decode.py flashinfer/cute_dsl/attention/mla_decode_fp8.py flashinfer/cute_dsl/attention/wrappers/batch_mla.py tests/attention/test_cute_dsl_mla_decode.py Tested: remote SM100 focused pytest, 3 passed in 11.09s, log /home/scratch.mingyangw_gpu/flashinfer-3161-validation/logs/pr3235-comment-fixes-focused-v2.log Not-tested: full test_cute_dsl_mla_decode.py rerun after review-feedback patch --- flashinfer/cute_dsl/attention/mla_config.py | 2 -- flashinfer/cute_dsl/attention/mla_decode.py | 2 -- flashinfer/cute_dsl/attention/mla_decode_fp8.py | 2 -- .../cute_dsl/attention/wrappers/batch_mla.py | 1 - tests/attention/test_cute_dsl_mla_decode.py | 14 +++++++++----- 5 files changed, 9 insertions(+), 12 deletions(-) diff --git a/flashinfer/cute_dsl/attention/mla_config.py b/flashinfer/cute_dsl/attention/mla_config.py index 405c744f02..6e6517bbe3 100644 --- a/flashinfer/cute_dsl/attention/mla_config.py +++ b/flashinfer/cute_dsl/attention/mla_config.py @@ -152,7 +152,6 @@ def can_implement( lse_dtype, mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, @@ -197,7 +196,6 @@ def can_implement_fp8( lse_dtype, mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, diff --git a/flashinfer/cute_dsl/attention/mla_decode.py b/flashinfer/cute_dsl/attention/mla_decode.py index 429c26c855..7d2d1c9d67 100644 --- a/flashinfer/cute_dsl/attention/mla_decode.py +++ b/flashinfer/cute_dsl/attention/mla_decode.py @@ -809,7 +809,6 @@ def can_implement( lse_dtype: Type[cutlass.Numeric], mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, @@ -828,7 +827,6 @@ def can_implement( lse_dtype, mma_qk_tiler_mn, mma_pv_tiler_mn, - split_kv, is_persistent, is_var_seq, is_var_split_kv, diff --git a/flashinfer/cute_dsl/attention/mla_decode_fp8.py b/flashinfer/cute_dsl/attention/mla_decode_fp8.py index 53b2869c64..0fdca7a211 100644 --- a/flashinfer/cute_dsl/attention/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/attention/mla_decode_fp8.py @@ -827,7 +827,6 @@ def can_implement( lse_dtype: Type[cutlass.Numeric], mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, @@ -846,7 +845,6 @@ def can_implement( lse_dtype, mma_qk_tiler_mn, mma_pv_tiler_mn, - split_kv, is_persistent, is_var_seq, is_var_split_kv, diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py index da148c2241..e3d203db22 100644 --- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py +++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py @@ -96,7 +96,6 @@ def _check_can_implement( cutlass.Float32, mma_qk_tiler_mn, mma_pv_tiler_mn, - 1, # split_kv (runtime, use 1 to pass the H<128 check) is_persistent, is_var_seq, is_var_split_kv, diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index dd40ec3ac1..4f8328d49f 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -324,6 +324,7 @@ def test_cute_dsl_mla_decode_via_api( ) assert out.shape == (batch_size, q_len, num_heads, latent_dim) + assert torch.isfinite(out).all(), "cute-dsl MLA decode produced non-finite values" @pytest.mark.parametrize("batch_size", [1, 4]) @@ -398,8 +399,11 @@ def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, enable_pdl, page_size=64) @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512, 2048]) @pytest.mark.parametrize("page_size", [64, 128]) +@pytest.mark.parametrize("num_heads", [128, 64]) @pytest.mark.parametrize("enable_pdl", [False]) -def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl): +def test_cute_dsl_mla_decode_fp8( + batch_size, seq_len_k, page_size, num_heads, enable_pdl +): """Test FP8 MLA decode kernel against FP32 reference.""" skip_if_unsupported() @@ -408,7 +412,6 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl): torch.manual_seed(42) device = torch.device("cuda") - num_heads = 128 latent_dim = 512 rope_dim = 64 q_len = 1 @@ -457,6 +460,7 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl): assert out.dtype == torch.bfloat16 assert out.shape == (batch_size, q_len, num_heads, latent_dim) + assert torch.isfinite(out).all(), "FP8 cute-dsl MLA decode produced non-finite" # Reference: compute in FP32 using FP8 values dequantized to FP32 kv_flat = kv_cache.reshape(-1, D_qk).to(torch.float32) @@ -969,7 +973,8 @@ def _make_fp8_mla_inputs( @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512]) @pytest.mark.parametrize("page_size", [64]) -def test_cute_dsl_mla_decode_fp8_alibi(batch_size, seq_len_k, page_size): +@pytest.mark.parametrize("num_heads", [128, 64]) +def test_cute_dsl_mla_decode_fp8_alibi(batch_size, seq_len_k, page_size, num_heads): """Test FP8 MLA decode with ALiBi variant.""" skip_if_unsupported() @@ -979,11 +984,10 @@ def test_cute_dsl_mla_decode_fp8_alibi(batch_size, seq_len_k, page_size): from flashinfer.cute_dsl.attention.fusion.variant import ALiBiAttention torch.manual_seed(42) - num_heads = 128 latent_dim = 512 rope_dim = 64 query, kv_cache, block_tables, seq_lens, workspace_buffer = _make_fp8_mla_inputs( - batch_size, seq_len_k, page_size + batch_size, seq_len_k, page_size, num_heads=num_heads ) softmax_scale = 1.0 / (latent_dim**0.5) output_scale = 1.0 From b4d193cde01671ad9b47e9c7a760b9e4d41b7a01 Mon Sep 17 00:00:00 2001 From: mingyangw Date: Thu, 7 May 2026 13:58:39 -0700 Subject: [PATCH 4/4] carrot Add the missing finite-output guard before the FP8 ALiBi MLA decode reference comparison so non-finite values fail explicitly before tolerance checks. Constraint: Preserve the review-requested commit subject. Confidence: high Scope-risk: narrow Tested: pre-commit run --files tests/attention/test_cute_dsl_mla_decode.py Tested: python3 -m py_compile tests/attention/test_cute_dsl_mla_decode.py Tested: remote SM100 focused pytest, 1 passed in 5.36s, log /home/scratch.mingyangw_gpu/flashinfer-3161-validation/logs/pr3235-alibi-finite-guard.log Not-tested: full test_cute_dsl_mla_decode.py rerun after this one-line guard --- tests/attention/test_cute_dsl_mla_decode.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 4f8328d49f..6147d10c86 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -1039,6 +1039,7 @@ def alibi_score_mod(score, batch_idx, qo_idx, kv_idx, head_idx): page_size, score_mod_fn=alibi_score_mod, ) + assert torch.isfinite(out).all(), "FP8 ALiBi MLA decode produced non-finite" torch.testing.assert_close( out.to(torch.float32), ref_out.to(torch.float32), atol=0.1, rtol=0.1 )