Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions flashinfer/mla/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,9 @@ def trtllm_batch_decode_with_kv_cache_mla(

return out
elif backend == "cute-dsl":
enable_pdl = (
device_support_pdl(query.device) if enable_pdl is None else enable_pdl
)
cc = get_compute_capability(query.device)
if cc[0] < 10:
raise RuntimeError(
Expand All @@ -823,10 +826,6 @@ def trtllm_batch_decode_with_kv_cache_mla(
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support sparse_mla_top_k"
)
if enable_pdl is not None:
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support enable_pdl"
)
if skip_softmax_threshold_scale_factor is not None:
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support skip_softmax_threshold_scale_factor"
Expand All @@ -850,6 +849,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
output_scale=bmm2_scale,
out=out,
is_var_seq=is_var_seq,
enable_pdl=enable_pdl,
)
else:
raise ValueError(f"Backend {backend} not supported")
Expand Down
6 changes: 6 additions & 0 deletions flashinfer/mla/cute_dsl/mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def _get_compiled_mla_kernel(
is_var_split_kv: bool,
skip_correction_threshold: float = 0.0,
is_workspace_size_zero: bool = False,
enable_pdl: bool = False,
) -> Callable:
"""Compile and cache an MLA decode kernel.

Expand Down Expand Up @@ -156,6 +157,7 @@ def _get_compiled_mla_kernel(
is_persistent=is_persistent,
is_var_seq=is_var_seq,
is_var_split_kv=is_var_split_kv,
enable_pdl=enable_pdl,
)

# All dimensions as sym_int — this matches the original kernel's use of
Expand Down Expand Up @@ -294,6 +296,7 @@ def cute_dsl_mla_decode(
out: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
is_var_seq: bool = True,
enable_pdl: bool = False,
) -> torch.Tensor:
"""CuTe DSL MLA decode kernel for Blackwell SM100.

Expand Down Expand Up @@ -335,6 +338,8 @@ def cute_dsl_mla_decode(
Whether the sequence length is variable.
If True, the sequence length is variable.
Otherwise, the sequence length is fixed for all the requests in the batch.
enable_pdl : bool
Whether to use Programmatic Dependent Launch (PDL).

Returns
-------
Expand Down Expand Up @@ -466,6 +471,7 @@ def cute_dsl_mla_decode(
is_var_split_kv=is_var_split_kv,
skip_correction_threshold=skip_correction_threshold,
is_workspace_size_zero=is_workspace_size_zero,
enable_pdl=enable_pdl,
)

# Call the kernel
Expand Down
15 changes: 15 additions & 0 deletions flashinfer/mla/cute_dsl/mla_decode_fp16.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def __init__(
is_persistent: bool,
is_var_seq: bool,
is_var_split_kv: bool,
enable_pdl: bool,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
):
"""Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel.

Expand All @@ -193,6 +194,8 @@ def __init__(
:type is_var_seq: bool
:param is_var_split_kv: Whether to use variable split KV
:type is_var_split_kv: bool
:param enable_pdl: Whether to use Programmatic Dependent Launch (PDL)
:type enable_pdl: bool
"""

self.latent_dim = 512
Expand All @@ -207,6 +210,7 @@ def __init__(
self.page_size = page_size
self.is_var_seq = is_var_seq
self.is_var_split_kv = is_var_split_kv
self.enable_pdl = enable_pdl
self.cluster_shape_mnk = (2, 1, 1)
self.use_2cta_instrs = True
# When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2),
Expand Down Expand Up @@ -709,6 +713,7 @@ class SplitKVKernelSharedStorage:
smem=SplitKVKernelSharedStorage.size_in_bytes(), # type: ignore[attr-defined]
stream=stream,
min_blocks_per_mp=1,
use_pdl=self.enable_pdl,
)
Comment thread
Observer007 marked this conversation as resolved.
if cutlass.const_expr(acc_o is not None):
self.reduction_kernel(
Expand All @@ -725,6 +730,7 @@ class SplitKVKernelSharedStorage:
smem=MAX_SPLITS * self.acc_dtype.width // 8,
stream=stream,
min_blocks_per_mp=1,
use_pdl=self.enable_pdl,
)

@cute.jit
Expand Down Expand Up @@ -979,6 +985,9 @@ def split_kv_kernel(
#
pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk)

if cutlass.const_expr(self.enable_pdl):
cute.arch.griddepcontrol_wait()

# ///////////////////////////////////////////////////////////////////////////////
# Load warps, including page table and data tensors
# ///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1187,6 +1196,8 @@ def split_kv_kernel(

tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
if cutlass.const_expr(self.enable_pdl):
cute.arch.griddepcontrol_launch_dependents()

# ///////////////////////////////////////////////////////////////////////////////
# Compute warp
Expand Down Expand Up @@ -1366,6 +1377,8 @@ def reduction_kernel(
lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype)
smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS))

if cutlass.const_expr(self.enable_pdl):
cute.arch.griddepcontrol_wait()
gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]]
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
if warp_idx == 0:
Expand Down Expand Up @@ -1428,6 +1441,8 @@ def reduction_kernel(
for j in cutlass.range_constexpr(elements_per_thread):
element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps
mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j]
if cutlass.const_expr(self.enable_pdl):
cute.arch.griddepcontrol_launch_dependents()
return

@staticmethod
Expand Down
15 changes: 15 additions & 0 deletions flashinfer/mla/cute_dsl/mla_decode_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def __init__(
is_persistent: bool,
is_var_seq: bool,
is_var_split_kv: bool,
enable_pdl: bool,
):
Comment thread
coderabbitai[bot] marked this conversation as resolved.
"""Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel.

Expand All @@ -189,6 +190,8 @@ def __init__(
:type is_var_seq: bool
:param is_var_split_kv: Whether to use variable split KV
:type is_var_split_kv: bool
:param enable_pdl: Whether to use Programmatic Dependent Launch (PDL)
:type enable_pdl: bool
"""

self.latent_dim = 512
Expand All @@ -203,6 +206,7 @@ def __init__(
self.page_size = page_size
self.is_var_seq = is_var_seq
self.is_var_split_kv = is_var_split_kv
self.enable_pdl = enable_pdl
self.cluster_shape_mnk = (2, 1, 1)
self.use_2cta_instrs = True
# When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2),
Expand Down Expand Up @@ -771,6 +775,7 @@ class SplitKVKernelSharedStorage:
smem=SplitKVKernelSharedStorage.size_in_bytes(), # type: ignore[attr-defined]
stream=stream,
min_blocks_per_mp=1,
use_pdl=self.enable_pdl,
)
if cutlass.const_expr(acc_o is not None):
self.reduction_kernel(
Expand All @@ -787,6 +792,7 @@ class SplitKVKernelSharedStorage:
smem=MAX_SPLITS * self.acc_dtype.width // 8,
stream=stream,
min_blocks_per_mp=1,
use_pdl=self.enable_pdl,
)

@cute.jit
Expand Down Expand Up @@ -1050,6 +1056,9 @@ def split_kv_kernel(
#
pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk)

if cutlass.const_expr(self.enable_pdl):
cute.arch.griddepcontrol_wait()

# ///////////////////////////////////////////////////////////////////////////////
# Load warps, including page table and data tensors
# ///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -1251,6 +1260,8 @@ def split_kv_kernel(

tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
if cutlass.const_expr(self.enable_pdl):
cute.arch.griddepcontrol_launch_dependents()

# ///////////////////////////////////////////////////////////////////////////////
# Compute warp
Expand Down Expand Up @@ -1428,6 +1439,8 @@ def reduction_kernel(
lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype)
smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS))

if cutlass.const_expr(self.enable_pdl):
cute.arch.griddepcontrol_wait()
gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]]
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
if warp_idx == 0:
Expand Down Expand Up @@ -1490,6 +1503,8 @@ def reduction_kernel(
for j in cutlass.range_constexpr(elements_per_thread):
element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps
mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j]
if cutlass.const_expr(self.enable_pdl):
cute.arch.griddepcontrol_launch_dependents()
return

@staticmethod
Expand Down
23 changes: 18 additions & 5 deletions tests/attention/test_cute_dsl_mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def torch_reference_mla(
@pytest.mark.parametrize("page_size", [128])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("q_len", [1, 2])
def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype, q_len):
@pytest.mark.parametrize("enable_pdl", [True, False])
def test_cute_dsl_mla_decode_fp16(
batch_size, seq_len_k, page_size, dtype, q_len, enable_pdl
):
"""Test FP16/BF16 MLA decode kernel."""
skip_if_unsupported()

Expand Down Expand Up @@ -158,6 +161,7 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype, q_len
softmax_scale=softmax_scale,
output_scale=output_scale,
is_var_seq=False,
enable_pdl=enable_pdl,
)

# Reference
Expand Down Expand Up @@ -187,7 +191,9 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype, q_len

@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seq_len_k", [128, 512])
def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=128):
def test_cute_dsl_mla_decode_variable_seq_len(
batch_size, seq_len_k, page_size=128, enable_pdl=False
):
"""Test MLA decode with variable sequence lengths across the batch."""
skip_if_unsupported()

Expand Down Expand Up @@ -241,6 +247,7 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1
softmax_scale=softmax_scale,
output_scale=output_scale,
is_var_seq=True,
enable_pdl=enable_pdl,
)

# Reference
Expand Down Expand Up @@ -268,7 +275,9 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1

@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seq_len_k", [128, 512])
def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128):
def test_cute_dsl_mla_decode_via_api(
batch_size, seq_len_k, page_size=128, enable_pdl=False
):
"""Test MLA decode via the trtllm_batch_decode_with_kv_cache_mla API with cute-dsl backend."""
skip_if_unsupported()

Expand Down Expand Up @@ -318,14 +327,16 @@ def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128):
bmm2_scale=1.0,
backend="cute-dsl",
is_var_seq=False,
enable_pdl=enable_pdl,
)

assert out.shape == (batch_size, q_len, num_heads, latent_dim)


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seq_len_k", [128, 512])
def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, page_size=64):
@pytest.mark.parametrize("enable_pdl", [True, False])
def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, enable_pdl, page_size=64):
"""Test cute-dsl backend output matches trtllm-gen backend output."""
skip_if_unsupported()

Expand Down Expand Up @@ -394,7 +405,8 @@ def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, page_size=64):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seq_len_k", [128, 512])
@pytest.mark.parametrize("page_size", [128])
def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size):
@pytest.mark.parametrize("enable_pdl", [True, False])
def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl):
"""Test FP8 MLA decode kernel against FP32 reference."""
skip_if_unsupported()

Expand Down Expand Up @@ -447,6 +459,7 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size):
max_seq_len=seq_len_k,
softmax_scale=softmax_scale,
output_scale=output_scale,
enable_pdl=enable_pdl,
)

assert out.dtype == torch.bfloat16
Expand Down
Loading