Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions flashinfer/cute_dsl/attention/mla_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ def can_implement(
lse_dtype,
mma_qk_tiler_mn: Tuple[int, int],
mma_pv_tiler_mn: Tuple[int, int],
split_kv: int,
is_persistent: bool,
is_var_seq: bool,
is_var_split_kv: bool,
Expand All @@ -175,7 +174,7 @@ def can_implement(
return False
if is_var_split_kv and not is_var_seq:
return False
if H > 128 or (H < 128 and split_kv != 1):
if H > 128:
Comment thread
saltyminty marked this conversation as resolved.
return False
if S < 1 or S > 4:
return False
Expand All @@ -197,7 +196,6 @@ def can_implement_fp8(
lse_dtype,
mma_qk_tiler_mn: Tuple[int, int],
mma_pv_tiler_mn: Tuple[int, int],
split_kv: int,
is_persistent: bool,
is_var_seq: bool,
is_var_split_kv: bool,
Expand All @@ -220,7 +218,7 @@ def can_implement_fp8(
return False
if is_var_split_kv and not is_var_seq:
return False
if H > 128 or (H < 128 and split_kv != 1):
if H > 128:
return False
if S <= 0 or S > 4:
return False
Expand Down
18 changes: 11 additions & 7 deletions flashinfer/cute_dsl/attention/mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,22 +744,28 @@ def initialize_workspace(
"""Initialize workspace tensors acc_o and acc_lse for split-KV."""
acc_o, acc_lse = None, None
if cutlass.const_expr(workspace is not None):
workspace_H = cutlass.max(H, cutlass.Int32(128))
align = 256 // self.q_dtype.width
acc_o_layout = cute.make_layout(
(H, split_kv, D, S, B),
(workspace_H, split_kv, D, S, B),
stride=(
cute.assume(split_kv * D, align),
cute.assume(D, align),
1,
cute.assume(split_kv * H * D, align),
cute.assume(H * split_kv * S * D, align),
cute.assume(split_kv * workspace_H * D, align),
cute.assume(workspace_H * split_kv * S * D, align),
),
)
acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype)
acc_o = cute.make_tensor(acc_o_iter, acc_o_layout)
acc_lse_layout = cute.make_layout(
(H, split_kv, S, B),
stride=(split_kv, 1, H * split_kv, H * split_kv * S),
(workspace_H, split_kv, S, B),
stride=(
split_kv,
1,
workspace_H * split_kv,
workspace_H * split_kv * S,
),
)
acc_lse_iter = cute.recast_ptr(
workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8,
Expand Down Expand Up @@ -803,7 +809,6 @@ def can_implement(
lse_dtype: Type[cutlass.Numeric],
mma_qk_tiler_mn: Tuple[int, int],
mma_pv_tiler_mn: Tuple[int, int],
split_kv: int,
is_persistent: bool,
is_var_seq: bool,
is_var_split_kv: bool,
Expand All @@ -822,7 +827,6 @@ def can_implement(
lse_dtype,
mma_qk_tiler_mn,
mma_pv_tiler_mn,
split_kv,
is_persistent,
is_var_seq,
is_var_split_kv,
Expand Down
18 changes: 11 additions & 7 deletions flashinfer/cute_dsl/attention/mla_decode_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,22 +762,28 @@ def initialize_workspace(
"""Initialize workspace tensors acc_o and acc_lse for split-KV."""
acc_o, acc_lse = None, None
if cutlass.const_expr(workspace is not None):
workspace_H = cutlass.max(H, cutlass.Int32(128))
align = 256 // self.q_dtype.width
acc_o_layout = cute.make_layout(
(H, split_kv, D, S, B),
(workspace_H, split_kv, D, S, B),
stride=(
cute.assume(split_kv * D, align),
cute.assume(D, align),
1,
cute.assume(split_kv * H * D, align),
cute.assume(H * split_kv * S * D, align),
cute.assume(split_kv * workspace_H * D, align),
cute.assume(workspace_H * split_kv * S * D, align),
),
)
acc_o_iter = cute.recast_ptr(workspace.iterator, dtype=acc_dtype)
acc_o = cute.make_tensor(acc_o_iter, acc_o_layout)
acc_lse_layout = cute.make_layout(
(H, split_kv, S, B),
stride=(split_kv, 1, H * split_kv, H * split_kv * S),
(workspace_H, split_kv, S, B),
stride=(
split_kv,
1,
workspace_H * split_kv,
workspace_H * split_kv * S,
),
)
acc_lse_iter = cute.recast_ptr(
workspace.iterator + cute.cosize(acc_o_layout) * acc_dtype.width // 8,
Expand Down Expand Up @@ -821,7 +827,6 @@ def can_implement(
lse_dtype: Type[cutlass.Numeric],
mma_qk_tiler_mn: Tuple[int, int],
mma_pv_tiler_mn: Tuple[int, int],
split_kv: int,
is_persistent: bool,
is_var_seq: bool,
is_var_split_kv: bool,
Expand All @@ -840,7 +845,6 @@ def can_implement(
lse_dtype,
mma_qk_tiler_mn,
mma_pv_tiler_mn,
split_kv,
is_persistent,
is_var_seq,
is_var_split_kv,
Expand Down
5 changes: 4 additions & 1 deletion flashinfer/cute_dsl/attention/scheduler/mla_persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,7 @@ def mla_get_workspace_size(
"""Get workspace size in bytes for split-KV MLA decode."""
if split_kv == 1:
return 0
return B * H * S * split_kv * (D + 1) * acc_dtype_width // 8
# Decode packs heads into a physical 128-wide MMA-M tile. For H < 128,
# split-KV partials can still touch the padded head lanes before reduction.
workspace_heads = max(H, 128)
return B * workspace_heads * S * split_kv * (D + 1) * acc_dtype_width // 8
17 changes: 0 additions & 17 deletions flashinfer/cute_dsl/attention/wrappers/batch_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def _check_can_implement(
cutlass.Float32,
mma_qk_tiler_mn,
mma_pv_tiler_mn,
1, # split_kv (runtime, use 1 to pass the H<128 check)
is_persistent,
is_var_seq,
is_var_split_kv,
Expand Down Expand Up @@ -592,11 +591,6 @@ def run(
B, q_len, H, self._kv_lora_rank, max_active_blocks
)

if H < 128 and split_kv != 1:
raise ValueError(
f"num_heads={H} < 128 requires split_kv==1, got split_kv={split_kv}"
)

# Prepare workspace
is_workspace_size_zero = workspace_size == 0
if is_workspace_size_zero:
Expand Down Expand Up @@ -778,23 +772,12 @@ def cute_dsl_mla_decode(
# Runtime validation
if max_seq_len <= 0:
raise ValueError(f"max_seq_len must be > 0, got {max_seq_len}")
if H < 128 and H != 1:
raise ValueError(
f"cute_dsl_mla_decode requires num_heads >= 128 (or 1 for reduction), got {H}"
)

# Cached split_kv and workspace_size computation
max_active_blocks = get_num_sm(query.device)
split_kv, workspace_size = _get_split_kv_and_workspace_size(
B, q_len, H, kv_lora_rank, max_active_blocks
)

if H < 128 and split_kv != 1:
raise ValueError(
f"cute_dsl_mla_decode: num_heads={H} < 128 requires split_kv==1, "
f"got split_kv={split_kv}"
)

# Prepare workspace
assert workspace_buffer.dtype == torch.int8, (
f"workspace_buffer must be torch.int8, got {workspace_buffer.dtype}"
Expand Down
21 changes: 13 additions & 8 deletions tests/attention/test_cute_dsl_mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,11 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size,
torch.testing.assert_close(out, ref_out_cast, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("batch_size", [1, 4, 128])
@pytest.mark.parametrize("seq_len_k", [128, 512])
@pytest.mark.parametrize("num_heads", [128, 64])
def test_cute_dsl_mla_decode_via_api(
batch_size, seq_len_k, page_size=128, enable_pdl=False
batch_size, seq_len_k, num_heads, page_size=128, enable_pdl=False
):
"""Test MLA decode via the trtllm_batch_decode_with_kv_cache_mla API with cute-dsl backend."""
skip_if_unsupported()
Expand All @@ -279,7 +280,6 @@ def test_cute_dsl_mla_decode_via_api(
torch.manual_seed(42)
device = torch.device("cuda")

num_heads = 128
latent_dim = 512
rope_dim = 64
q_len = 1
Expand Down Expand Up @@ -324,6 +324,7 @@ def test_cute_dsl_mla_decode_via_api(
)

assert out.shape == (batch_size, q_len, num_heads, latent_dim)
assert torch.isfinite(out).all(), "cute-dsl MLA decode produced non-finite values"


@pytest.mark.parametrize("batch_size", [1, 4])
Expand Down Expand Up @@ -398,8 +399,11 @@ def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, enable_pdl, page_size=64)
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seq_len_k", [128, 512, 2048])
@pytest.mark.parametrize("page_size", [64, 128])
@pytest.mark.parametrize("num_heads", [128, 64])
@pytest.mark.parametrize("enable_pdl", [False])
def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl):
def test_cute_dsl_mla_decode_fp8(
batch_size, seq_len_k, page_size, num_heads, enable_pdl
):
"""Test FP8 MLA decode kernel against FP32 reference."""
skip_if_unsupported()

Expand All @@ -408,7 +412,6 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl):
torch.manual_seed(42)
device = torch.device("cuda")

num_heads = 128
latent_dim = 512
rope_dim = 64
q_len = 1
Expand Down Expand Up @@ -457,6 +460,7 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl):

assert out.dtype == torch.bfloat16
assert out.shape == (batch_size, q_len, num_heads, latent_dim)
assert torch.isfinite(out).all(), "FP8 cute-dsl MLA decode produced non-finite"

# Reference: compute in FP32 using FP8 values dequantized to FP32
kv_flat = kv_cache.reshape(-1, D_qk).to(torch.float32)
Expand Down Expand Up @@ -969,7 +973,8 @@ def _make_fp8_mla_inputs(
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seq_len_k", [128, 512])
@pytest.mark.parametrize("page_size", [64])
def test_cute_dsl_mla_decode_fp8_alibi(batch_size, seq_len_k, page_size):
@pytest.mark.parametrize("num_heads", [128, 64])
def test_cute_dsl_mla_decode_fp8_alibi(batch_size, seq_len_k, page_size, num_heads):
"""Test FP8 MLA decode with ALiBi variant."""
skip_if_unsupported()

Expand All @@ -979,11 +984,10 @@ def test_cute_dsl_mla_decode_fp8_alibi(batch_size, seq_len_k, page_size):
from flashinfer.cute_dsl.attention.fusion.variant import ALiBiAttention

torch.manual_seed(42)
num_heads = 128
latent_dim = 512
rope_dim = 64
query, kv_cache, block_tables, seq_lens, workspace_buffer = _make_fp8_mla_inputs(
batch_size, seq_len_k, page_size
batch_size, seq_len_k, page_size, num_heads=num_heads
)
softmax_scale = 1.0 / (latent_dim**0.5)
output_scale = 1.0
Expand Down Expand Up @@ -1035,6 +1039,7 @@ def alibi_score_mod(score, batch_idx, qo_idx, kv_idx, head_idx):
page_size,
score_mod_fn=alibi_score_mod,
)
assert torch.isfinite(out).all(), "FP8 ALiBi MLA decode produced non-finite"
torch.testing.assert_close(
out.to(torch.float32), ref_out.to(torch.float32), atol=0.1, rtol=0.1
)
Expand Down
8 changes: 5 additions & 3 deletions tests/attention/test_trtllm_gen_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,9 +818,11 @@ def test_trtllm_batch_decode_mla(
pytest.skip("XQA MLA does not support smaller MLA dimensions yet.")
if backend == "xqa" and layer_dimensions.num_heads != 128:
pytest.skip("XQA MLA only supports 128 query heads (head_group_ratio=128)")
if backend == "cute-dsl" and layer_dimensions.num_heads < 128:
pytest.skip("cute-dsl MLA requires num_heads >= 128")

if (
backend == "cute-dsl"
and layer_dimensions.head_dimensions == smaller_mla_dimensions
):
pytest.skip("cute-dsl MLA requires 512 latent dim and 64 rope dim")
trtllm_batch_decode_mla(
layer_dimensions,
batch_size,
Expand Down
Loading