diff --git a/flashinfer/cute_dsl/attention/mla_config.py b/flashinfer/cute_dsl/attention/mla_config.py index ff2730d80d..6e6517bbe3 100644 --- a/flashinfer/cute_dsl/attention/mla_config.py +++ b/flashinfer/cute_dsl/attention/mla_config.py @@ -152,7 +152,6 @@ def can_implement( lse_dtype, mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, @@ -175,7 +174,7 @@ def can_implement( return False if is_var_split_kv and not is_var_seq: return False - if H > 128 or (H < 128 and split_kv != 1): + if H > 128: return False if S < 1 or S > 4: return False @@ -197,7 +196,6 @@ def can_implement_fp8( lse_dtype, mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, @@ -220,7 +218,7 @@ def can_implement_fp8( return False if is_var_split_kv and not is_var_seq: return False - if H > 128 or (H < 128 and split_kv != 1): + if H > 128: return False if S <= 0 or S > 4: return False diff --git a/flashinfer/cute_dsl/attention/mla_decode.py b/flashinfer/cute_dsl/attention/mla_decode.py index c454a9a43f..7d2d1c9d67 100644 --- a/flashinfer/cute_dsl/attention/mla_decode.py +++ b/flashinfer/cute_dsl/attention/mla_decode.py @@ -744,22 +744,28 @@ def initialize_workspace( """Initialize workspace tensors acc_o and acc_lse for split-KV.""" acc_o, acc_lse = None, None if cutlass.const_expr(workspace is not None): + workspace_H = cutlass.max(H, cutlass.Int32(128)) align = 256 // self.q_dtype.width acc_o_layout = cute.make_layout( - (H, split_kv, D, S, B), + (workspace_H, split_kv, D, S, B), stride=( cute.assume(split_kv * D, align), cute.assume(D, align), 1, - cute.assume(split_kv * H * D, align), - cute.assume(H * split_kv * S * D, align), + cute.assume(split_kv * workspace_H * D, align), + cute.assume(workspace_H * split_kv * S * D, align), ), ) acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) acc_lse_layout = cute.make_layout( - (H, split_kv, S, B), - stride=(split_kv, 1, H * split_kv, H * split_kv * S), + (workspace_H, split_kv, S, B), + stride=( + split_kv, + 1, + workspace_H * split_kv, + workspace_H * split_kv * S, + ), ) acc_lse_iter = cute.recast_ptr( workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, @@ -803,7 +809,6 @@ def can_implement( lse_dtype: Type[cutlass.Numeric], mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, @@ -822,7 +827,6 @@ def can_implement( lse_dtype, mma_qk_tiler_mn, mma_pv_tiler_mn, - split_kv, is_persistent, is_var_seq, is_var_split_kv, diff --git a/flashinfer/cute_dsl/attention/mla_decode_fp8.py b/flashinfer/cute_dsl/attention/mla_decode_fp8.py index dcf2e9ee5c..0fdca7a211 100644 --- a/flashinfer/cute_dsl/attention/mla_decode_fp8.py +++ b/flashinfer/cute_dsl/attention/mla_decode_fp8.py @@ -762,22 +762,28 @@ def initialize_workspace( """Initialize workspace tensors acc_o and acc_lse for split-KV.""" acc_o, acc_lse = None, None if cutlass.const_expr(workspace is not None): + workspace_H = cutlass.max(H, cutlass.Int32(128)) align = 256 // self.q_dtype.width acc_o_layout = cute.make_layout( - (H, split_kv, D, S, B), + (workspace_H, split_kv, D, S, B), stride=( cute.assume(split_kv * D, align), cute.assume(D, align), 1, - cute.assume(split_kv * H * D, align), - cute.assume(H * split_kv * S * D, align), + cute.assume(split_kv * workspace_H * D, align), + cute.assume(workspace_H * split_kv * S * D, align), ), ) acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype) acc_o = cute.make_tensor(acc_o_iter, acc_o_layout) acc_lse_layout = cute.make_layout( - (H, split_kv, S, B), - stride=(split_kv, 1, H * split_kv, H * split_kv * S), + (workspace_H, split_kv, S, B), + stride=( + split_kv, + 1, + workspace_H * split_kv, + workspace_H * split_kv * S, + ), ) acc_lse_iter = cute.recast_ptr( workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8, @@ -821,7 +827,6 @@ def can_implement( lse_dtype: Type[cutlass.Numeric], mma_qk_tiler_mn: Tuple[int, int], mma_pv_tiler_mn: Tuple[int, int], - split_kv: int, is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, @@ -840,7 +845,6 @@ def can_implement( lse_dtype, mma_qk_tiler_mn, mma_pv_tiler_mn, - split_kv, is_persistent, is_var_seq, is_var_split_kv, diff --git a/flashinfer/cute_dsl/attention/scheduler/mla_persistent.py b/flashinfer/cute_dsl/attention/scheduler/mla_persistent.py index ff10dcd5d5..2e62ed4fec 100644 --- a/flashinfer/cute_dsl/attention/scheduler/mla_persistent.py +++ b/flashinfer/cute_dsl/attention/scheduler/mla_persistent.py @@ -288,4 +288,7 @@ def mla_get_workspace_size( """Get workspace size in bytes for split-KV MLA decode.""" if split_kv == 1: return 0 - return B * H * S * split_kv * (D + 1) * acc_dtype_width // 8 + # Decode packs heads into a physical 128-wide MMA-M tile. For H < 128, + # split-KV partials can still touch the padded head lanes before reduction. + workspace_heads = max(H, 128) + return B * workspace_heads * S * split_kv * (D + 1) * acc_dtype_width // 8 diff --git a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py index b7aabc3629..e3d203db22 100644 --- a/flashinfer/cute_dsl/attention/wrappers/batch_mla.py +++ b/flashinfer/cute_dsl/attention/wrappers/batch_mla.py @@ -96,7 +96,6 @@ def _check_can_implement( cutlass.Float32, mma_qk_tiler_mn, mma_pv_tiler_mn, - 1, # split_kv (runtime, use 1 to pass the H<128 check) is_persistent, is_var_seq, is_var_split_kv, @@ -592,11 +591,6 @@ def run( B, q_len, H, self._kv_lora_rank, max_active_blocks ) - if H < 128 and split_kv != 1: - raise ValueError( - f"num_heads={H} < 128 requires split_kv==1, got split_kv={split_kv}" - ) - # Prepare workspace is_workspace_size_zero = workspace_size == 0 if is_workspace_size_zero: @@ -778,23 +772,12 @@ def cute_dsl_mla_decode( # Runtime validation if max_seq_len <= 0: raise ValueError(f"max_seq_len must be > 0, got {max_seq_len}") - if H < 128 and H != 1: - raise ValueError( - f"cute_dsl_mla_decode requires num_heads >= 128 (or 1 for reduction), got {H}" - ) - # Cached split_kv and workspace_size computation max_active_blocks = get_num_sm(query.device) split_kv, workspace_size = _get_split_kv_and_workspace_size( B, q_len, H, kv_lora_rank, max_active_blocks ) - if H < 128 and split_kv != 1: - raise ValueError( - f"cute_dsl_mla_decode: num_heads={H} < 128 requires split_kv==1, " - f"got split_kv={split_kv}" - ) - # Prepare workspace assert workspace_buffer.dtype == torch.int8, ( f"workspace_buffer must be torch.int8, got {workspace_buffer.dtype}" diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 3abe867595..6147d10c86 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -266,10 +266,11 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size, torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2) -@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("batch_size", [1, 4, 128]) @pytest.mark.parametrize("seq_len_k", [128, 512]) +@pytest.mark.parametrize("num_heads", [128, 64]) def test_cute_dsl_mla_decode_via_api( - batch_size, seq_len_k, page_size=128, enable_pdl=False + batch_size, seq_len_k, num_heads, page_size=128, enable_pdl=False ): """Test MLA decode via the trtllm_batch_decode_with_kv_cache_mla API with cute-dsl backend.""" skip_if_unsupported() @@ -279,7 +280,6 @@ def test_cute_dsl_mla_decode_via_api( torch.manual_seed(42) device = torch.device("cuda") - num_heads = 128 latent_dim = 512 rope_dim = 64 q_len = 1 @@ -324,6 +324,7 @@ def test_cute_dsl_mla_decode_via_api( ) assert out.shape == (batch_size, q_len, num_heads, latent_dim) + assert torch.isfinite(out).all(), "cute-dsl MLA decode produced non-finite values" @pytest.mark.parametrize("batch_size", [1, 4]) @@ -398,8 +399,11 @@ def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, enable_pdl, page_size=64) @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512, 2048]) @pytest.mark.parametrize("page_size", [64, 128]) +@pytest.mark.parametrize("num_heads", [128, 64]) @pytest.mark.parametrize("enable_pdl", [False]) -def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl): +def test_cute_dsl_mla_decode_fp8( + batch_size, seq_len_k, page_size, num_heads, enable_pdl +): """Test FP8 MLA decode kernel against FP32 reference.""" skip_if_unsupported() @@ -408,7 +412,6 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl): torch.manual_seed(42) device = torch.device("cuda") - num_heads = 128 latent_dim = 512 rope_dim = 64 q_len = 1 @@ -457,6 +460,7 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl): assert out.dtype == torch.bfloat16 assert out.shape == (batch_size, q_len, num_heads, latent_dim) + assert torch.isfinite(out).all(), "FP8 cute-dsl MLA decode produced non-finite" # Reference: compute in FP32 using FP8 values dequantized to FP32 kv_flat = kv_cache.reshape(-1, D_qk).to(torch.float32) @@ -969,7 +973,8 @@ def _make_fp8_mla_inputs( @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512]) @pytest.mark.parametrize("page_size", [64]) -def test_cute_dsl_mla_decode_fp8_alibi(batch_size, seq_len_k, page_size): +@pytest.mark.parametrize("num_heads", [128, 64]) +def test_cute_dsl_mla_decode_fp8_alibi(batch_size, seq_len_k, page_size, num_heads): """Test FP8 MLA decode with ALiBi variant.""" skip_if_unsupported() @@ -979,11 +984,10 @@ def test_cute_dsl_mla_decode_fp8_alibi(batch_size, seq_len_k, page_size): from flashinfer.cute_dsl.attention.fusion.variant import ALiBiAttention torch.manual_seed(42) - num_heads = 128 latent_dim = 512 rope_dim = 64 query, kv_cache, block_tables, seq_lens, workspace_buffer = _make_fp8_mla_inputs( - batch_size, seq_len_k, page_size + batch_size, seq_len_k, page_size, num_heads=num_heads ) softmax_scale = 1.0 / (latent_dim**0.5) output_scale = 1.0 @@ -1035,6 +1039,7 @@ def alibi_score_mod(score, batch_idx, qo_idx, kv_idx, head_idx): page_size, score_mod_fn=alibi_score_mod, ) + assert torch.isfinite(out).all(), "FP8 ALiBi MLA decode produced non-finite" torch.testing.assert_close( out.to(torch.float32), ref_out.to(torch.float32), atol=0.1, rtol=0.1 ) diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index c1cf3d8a50..bac935ea62 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -818,9 +818,11 @@ def test_trtllm_batch_decode_mla( pytest.skip("XQA MLA does not support smaller MLA dimensions yet.") if backend == "xqa" and layer_dimensions.num_heads != 128: pytest.skip("XQA MLA only supports 128 query heads (head_group_ratio=128)") - if backend == "cute-dsl" and layer_dimensions.num_heads < 128: - pytest.skip("cute-dsl MLA requires num_heads >= 128") - + if ( + backend == "cute-dsl" + and layer_dimensions.head_dimensions == smaller_mla_dimensions + ): + pytest.skip("cute-dsl MLA requires 512 latent dim and 64 rope dim") trtllm_batch_decode_mla( layer_dimensions, batch_size,