Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions flashinfer/mla/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,9 @@ def trtllm_batch_decode_with_kv_cache_mla(

return out
elif backend == "cute-dsl":
enable_pdl = (
device_support_pdl(query.device) if enable_pdl is None else enable_pdl
)
cc = get_compute_capability(query.device)
if cc[0] < 10:
raise RuntimeError(
Expand All @@ -823,10 +826,6 @@ def trtllm_batch_decode_with_kv_cache_mla(
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support sparse_mla_top_k"
)
if enable_pdl is not None:
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support enable_pdl"
)
if skip_softmax_threshold_scale_factor is not None:
raise ValueError(
"cute-dsl backend (MLA decode kernel) does not support skip_softmax_threshold_scale_factor"
Expand All @@ -850,6 +849,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
output_scale=bmm2_scale,
out=out,
is_var_seq=is_var_seq,
enable_pdl=enable_pdl,
)
else:
raise ValueError(f"Backend {backend} not supported")
Expand Down
11 changes: 11 additions & 0 deletions flashinfer/mla/cute_dsl/mla_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import torch
from cutlass import Float32, Int32

from ...utils import device_support_pdl

from .mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16
from .mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8
from flashinfer.cute_dsl.utils import (
Expand Down Expand Up @@ -118,6 +120,7 @@ def _get_compiled_mla_kernel(
is_var_split_kv: bool,
skip_correction_threshold: float = 0.0,
is_workspace_size_zero: bool = False,
enable_pdl: bool = False,
) -> Callable:
"""Compile and cache an MLA decode kernel.

Expand Down Expand Up @@ -156,6 +159,7 @@ def _get_compiled_mla_kernel(
is_persistent=is_persistent,
is_var_seq=is_var_seq,
is_var_split_kv=is_var_split_kv,
enable_pdl=enable_pdl,
)

# All dimensions as sym_int β€” this matches the original kernel's use of
Expand Down Expand Up @@ -294,6 +298,7 @@ def cute_dsl_mla_decode(
out: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
is_var_seq: bool = True,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
"""CuTe DSL MLA decode kernel for Blackwell SM100.

Expand Down Expand Up @@ -335,6 +340,9 @@ def cute_dsl_mla_decode(
Whether the sequence length is variable.
If True, the sequence length is variable.
Otherwise,the sequence length is fixed for all the requests in the batch.
enable_pdl : Optional[bool], default=None
Whether to enable Programmatic Dependent Launch (PDL).
If None, auto-detects based on device capability.

Returns
-------
Expand Down Expand Up @@ -452,6 +460,8 @@ def cute_dsl_mla_decode(
is_var_split_kv=is_var_split_kv,
)

enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl

# Get compiled kernel (cached after first compile)
# Note: when is_workspace_size_zero is True, workspace_bytes is None and it will launch one kernel without workspace.
# Otherwise, workspace_bytes is not None and it will launch two kernels.
Expand All @@ -466,6 +476,7 @@ def cute_dsl_mla_decode(
is_var_split_kv=is_var_split_kv,
skip_correction_threshold=skip_correction_threshold,
is_workspace_size_zero=is_workspace_size_zero,
enable_pdl=enable_pdl,
)

# Call the kernel
Expand Down
Loading
Loading