From f4af5db6d2f595e585c7496cf2d1ad2577a605f3 Mon Sep 17 00:00:00 2001 From: jingzec Date: Fri, 27 Mar 2026 06:32:09 -0700 Subject: [PATCH 1/4] add pdl support for cute dsl mla decode kernel support --- flashinfer/mla/_core.py | 8 ++++---- flashinfer/mla/cute_dsl/mla_decode.py | 6 ++++++ flashinfer/mla/cute_dsl/mla_decode_fp16.py | 15 +++++++++++++++ flashinfer/mla/cute_dsl/mla_decode_fp8.py | 15 +++++++++++++++ tests/attention/test_cute_dsl_mla_decode.py | 17 ++++++++++++----- 5 files changed, 52 insertions(+), 9 deletions(-) diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index bca53627bf..d722abaeb6 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -798,6 +798,9 @@ def trtllm_batch_decode_with_kv_cache_mla( return out elif backend == "cute-dsl": + enable_pdl = ( + device_support_pdl(query.device) if enable_pdl is None else enable_pdl + ) cc = get_compute_capability(query.device) if cc[0] < 10: raise RuntimeError( @@ -823,10 +826,6 @@ def trtllm_batch_decode_with_kv_cache_mla( raise ValueError( "cute-dsl backend (MLA decode kernel) does not support sparse_mla_top_k" ) - if enable_pdl is not None: - raise ValueError( - "cute-dsl backend (MLA decode kernel) does not support enable_pdl" - ) if skip_softmax_threshold_scale_factor is not None: raise ValueError( "cute-dsl backend (MLA decode kernel) does not support skip_softmax_threshold_scale_factor" @@ -850,6 +849,7 @@ def trtllm_batch_decode_with_kv_cache_mla( output_scale=bmm2_scale, out=out, is_var_seq=is_var_seq, + enable_pdl=enable_pdl, ) else: raise ValueError(f"Backend {backend} not supported") diff --git a/flashinfer/mla/cute_dsl/mla_decode.py b/flashinfer/mla/cute_dsl/mla_decode.py index ad19eb821b..aa5b6a3723 100644 --- a/flashinfer/mla/cute_dsl/mla_decode.py +++ b/flashinfer/mla/cute_dsl/mla_decode.py @@ -118,6 +118,7 @@ def _get_compiled_mla_kernel( is_var_split_kv: bool, skip_correction_threshold: float = 0.0, is_workspace_size_zero: bool = False, + enable_pdl: bool = False, ) -> Callable: """Compile and cache an MLA decode kernel. @@ -156,6 +157,7 @@ def _get_compiled_mla_kernel( is_persistent=is_persistent, is_var_seq=is_var_seq, is_var_split_kv=is_var_split_kv, + enable_pdl=enable_pdl, ) # All dimensions as sym_int — this matches the original kernel's use of @@ -294,6 +296,7 @@ def cute_dsl_mla_decode( out: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = None, is_var_seq: bool = True, + enable_pdl: bool = False, ) -> torch.Tensor: """CuTe DSL MLA decode kernel for Blackwell SM100. @@ -335,6 +338,8 @@ def cute_dsl_mla_decode( Whether the sequence length is variable. If True, the sequence length is variable. Otherwise,the sequence length is fixed for all the requests in the batch. + enable_pdl : bool + Whether to use PDL. Returns ------- @@ -466,6 +471,7 @@ def cute_dsl_mla_decode( is_var_split_kv=is_var_split_kv, skip_correction_threshold=skip_correction_threshold, is_workspace_size_zero=is_workspace_size_zero, + enable_pdl=enable_pdl, ) # Call the kernel diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp16.py b/flashinfer/mla/cute_dsl/mla_decode_fp16.py index d91385088f..a570eb9ee9 100644 --- a/flashinfer/mla/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/mla/cute_dsl/mla_decode_fp16.py @@ -170,6 +170,7 @@ def __init__( is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, + enable_pdl: bool, ): """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. @@ -193,6 +194,8 @@ def __init__( :type is_var_seq: bool :param is_var_split_kv: Whether to use variable split KV :type is_var_split_kv: bool + :param enable_pdl: Whether to use PDL + :type enable_pdl: bool """ self.latent_dim = 512 @@ -207,6 +210,7 @@ def __init__( self.page_size = page_size self.is_var_seq = is_var_seq self.is_var_split_kv = is_var_split_kv + self.enable_pdl = enable_pdl self.cluster_shape_mnk = (2, 1, 1) self.use_2cta_instrs = True # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), @@ -709,6 +713,7 @@ class SplitKVKernelSharedStorage: smem=SplitKVKernelSharedStorage.size_in_bytes(), # type: ignore[attr-defined] stream=stream, min_blocks_per_mp=1, + use_pdl=self.enable_pdl, ) if cutlass.const_expr(acc_o is not None): self.reduction_kernel( @@ -725,6 +730,7 @@ class SplitKVKernelSharedStorage: smem=MAX_SPLITS * self.acc_dtype.width // 8, stream=stream, min_blocks_per_mp=1, + use_pdl=self.enable_pdl, ) @cute.jit @@ -979,6 +985,9 @@ def split_kv_kernel( # pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk) + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() + # /////////////////////////////////////////////////////////////////////////////// # Load warps, including page table and data tensors # /////////////////////////////////////////////////////////////////////////////// @@ -1187,6 +1196,8 @@ def split_kv_kernel( tmem.relinquish_alloc_permit() tmem.free(tmem_ptr) + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() # /////////////////////////////////////////////////////////////////////////////// # Compute warp @@ -1366,6 +1377,8 @@ def reduction_kernel( lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype) smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]] warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 0: @@ -1428,6 +1441,8 @@ def reduction_kernel( for j in cutlass.range_constexpr(elements_per_thread): element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j] + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() return @staticmethod diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp8.py b/flashinfer/mla/cute_dsl/mla_decode_fp8.py index d0e5d83242..8d26879020 100644 --- a/flashinfer/mla/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/mla/cute_dsl/mla_decode_fp8.py @@ -166,6 +166,7 @@ def __init__( is_persistent: bool, is_var_seq: bool, is_var_split_kv: bool, + enable_pdl: bool, ): """Initializes the configuration for a Blackwell Multi-Head Latent Attention (MLA) kernel. @@ -189,6 +190,8 @@ def __init__( :type is_var_seq: bool :param is_var_split_kv: Whether to use variable split KV :type is_var_split_kv: bool + :param enable_pdl: Whether to use PDL + :type enable_pdl: bool """ self.latent_dim = 512 @@ -203,6 +206,7 @@ def __init__( self.page_size = page_size self.is_var_seq = is_var_seq self.is_var_split_kv = is_var_split_kv + self.enable_pdl = enable_pdl self.cluster_shape_mnk = (2, 1, 1) self.use_2cta_instrs = True # When using 2 CTAs with m=128: warps 0-1 handle accumulation for first half [0, n/2), @@ -771,6 +775,7 @@ class SplitKVKernelSharedStorage: smem=SplitKVKernelSharedStorage.size_in_bytes(), # type: ignore[attr-defined] stream=stream, min_blocks_per_mp=1, + use_pdl=self.enable_pdl, ) if cutlass.const_expr(acc_o is not None): self.reduction_kernel( @@ -787,6 +792,7 @@ class SplitKVKernelSharedStorage: smem=MAX_SPLITS * self.acc_dtype.width // 8, stream=stream, min_blocks_per_mp=1, + use_pdl=self.enable_pdl, ) @cute.jit @@ -1050,6 +1056,9 @@ def split_kv_kernel( # pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mnk) + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() + # /////////////////////////////////////////////////////////////////////////////// # Load warps, including page table and data tensors # /////////////////////////////////////////////////////////////////////////////// @@ -1251,6 +1260,8 @@ def split_kv_kernel( tmem.relinquish_alloc_permit() tmem.free(tmem_ptr) + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() # /////////////////////////////////////////////////////////////////////////////// # Compute warp @@ -1428,6 +1439,8 @@ def reduction_kernel( lse_scale_ptr = cute.recast_ptr(storage, dtype=self.acc_dtype) smem_lse_scale = cute.make_tensor(lse_scale_ptr, cute.make_layout(MAX_SPLITS)) + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_wait() gLSE = mAccLSE[blk_coord[0], None, blk_coord[1], blk_coord[2]] warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) if warp_idx == 0: @@ -1490,6 +1503,8 @@ def reduction_kernel( for j in cutlass.range_constexpr(elements_per_thread): element_idx = tidx + j * self.threads_per_warp * self.num_compute_warps mO[blk_coord[0], element_idx, blk_coord[1], blk_coord[2]] = rO[j] + if cutlass.const_expr(self.enable_pdl): + cute.arch.griddepcontrol_launch_dependents() return @staticmethod diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index d9427460f5..21f2083670 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -101,7 +101,8 @@ def torch_reference_mla( @pytest.mark.parametrize("page_size", [128]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("q_len", [1, 2]) -def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype, q_len): +@pytest.mark.parametrize("enable_pdl", [True, False]) +def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype, q_len, enable_pdl): """Test FP16/BF16 MLA decode kernel.""" skip_if_unsupported() @@ -158,6 +159,7 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype, q_len softmax_scale=softmax_scale, output_scale=output_scale, is_var_seq=False, + enable_pdl=enable_pdl, ) # Reference @@ -187,7 +189,7 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype, q_len @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512]) -def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=128): +def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=128, enable_pdl=False): """Test MLA decode with variable sequence lengths across the batch.""" skip_if_unsupported() @@ -241,6 +243,7 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1 softmax_scale=softmax_scale, output_scale=output_scale, is_var_seq=True, + enable_pdl=enable_pdl, ) # Reference @@ -268,7 +271,7 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1 @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512]) -def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128): +def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128, enable_pdl=False): """Test MLA decode via the trtllm_batch_decode_with_kv_cache_mla API with cute-dsl backend.""" skip_if_unsupported() @@ -318,6 +321,7 @@ def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128): bmm2_scale=1.0, backend="cute-dsl", is_var_seq=False, + enable_pdl=enable_pdl, ) assert out.shape == (batch_size, q_len, num_heads, latent_dim) @@ -325,7 +329,8 @@ def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128): @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512]) -def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, page_size=64): +@pytest.mark.parametrize("enable_pdl", [True, False]) +def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, enable_pdl, page_size=64): """Test cute-dsl backend output matches trtllm-gen backend output.""" skip_if_unsupported() @@ -394,7 +399,8 @@ def test_cute_dsl_vs_trtllm_gen(batch_size, seq_len_k, page_size=64): @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512]) @pytest.mark.parametrize("page_size", [128]) -def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size): +@pytest.mark.parametrize("enable_pdl", [True, False]) +def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size, enable_pdl): """Test FP8 MLA decode kernel against FP32 reference.""" skip_if_unsupported() @@ -447,6 +453,7 @@ def test_cute_dsl_mla_decode_fp8(batch_size, seq_len_k, page_size): max_seq_len=seq_len_k, softmax_scale=softmax_scale, output_scale=output_scale, + enable_pdl=enable_pdl, ) assert out.dtype == torch.bfloat16 From c66d55a1b174afbc4918c72993da20cee4a2ea4e Mon Sep 17 00:00:00 2001 From: jingzec Date: Fri, 27 Mar 2026 07:15:18 -0700 Subject: [PATCH 2/4] apply ruff --- tests/attention/test_cute_dsl_mla_decode.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/attention/test_cute_dsl_mla_decode.py b/tests/attention/test_cute_dsl_mla_decode.py index 21f2083670..8a5e572a96 100644 --- a/tests/attention/test_cute_dsl_mla_decode.py +++ b/tests/attention/test_cute_dsl_mla_decode.py @@ -102,7 +102,9 @@ def torch_reference_mla( @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("q_len", [1, 2]) @pytest.mark.parametrize("enable_pdl", [True, False]) -def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype, q_len, enable_pdl): +def test_cute_dsl_mla_decode_fp16( + batch_size, seq_len_k, page_size, dtype, q_len, enable_pdl +): """Test FP16/BF16 MLA decode kernel.""" skip_if_unsupported() @@ -189,7 +191,9 @@ def test_cute_dsl_mla_decode_fp16(batch_size, seq_len_k, page_size, dtype, q_len @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512]) -def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=128, enable_pdl=False): +def test_cute_dsl_mla_decode_variable_seq_len( + batch_size, seq_len_k, page_size=128, enable_pdl=False +): """Test MLA decode with variable sequence lengths across the batch.""" skip_if_unsupported() @@ -271,7 +275,9 @@ def test_cute_dsl_mla_decode_variable_seq_len(batch_size, seq_len_k, page_size=1 @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("seq_len_k", [128, 512]) -def test_cute_dsl_mla_decode_via_api(batch_size, seq_len_k, page_size=128, enable_pdl=False): +def test_cute_dsl_mla_decode_via_api( + batch_size, seq_len_k, page_size=128, enable_pdl=False +): """Test MLA decode via the trtllm_batch_decode_with_kv_cache_mla API with cute-dsl backend.""" skip_if_unsupported() From d73238710f26870d29496bbb1171d3030494c21c Mon Sep 17 00:00:00 2001 From: yzh119 Date: Fri, 27 Mar 2026 23:43:05 -0700 Subject: [PATCH 3/4] fix: pass missing enable_pdl arg to MLA kernel constructors The BlackwellMultiHeadLatentAttentionForward{FP8,FP16} constructors require enable_pdl but the run() harnesses weren't passing it (mypy call-arg error). Also follow public API convention in cute_dsl_mla_decode by defaulting to Optional[bool] = None with device_support_pdl() auto-detection. Co-Authored-By: Claude Opus 4.6 (1M context) --- flashinfer/mla/cute_dsl/mla_decode.py | 11 ++++++++--- flashinfer/mla/cute_dsl/mla_decode_fp16.py | 9 +++++++++ flashinfer/mla/cute_dsl/mla_decode_fp8.py | 9 +++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/flashinfer/mla/cute_dsl/mla_decode.py b/flashinfer/mla/cute_dsl/mla_decode.py index aa5b6a3723..1887e4e25c 100644 --- a/flashinfer/mla/cute_dsl/mla_decode.py +++ b/flashinfer/mla/cute_dsl/mla_decode.py @@ -28,6 +28,8 @@ import torch from cutlass import Float32, Int32 +from ...utils import device_support_pdl + from .mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16 from .mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8 from flashinfer.cute_dsl.utils import ( @@ -296,7 +298,7 @@ def cute_dsl_mla_decode( out: Optional[torch.Tensor] = None, out_dtype: Optional[torch.dtype] = None, is_var_seq: bool = True, - enable_pdl: bool = False, + enable_pdl: Optional[bool] = None, ) -> torch.Tensor: """CuTe DSL MLA decode kernel for Blackwell SM100. @@ -338,8 +340,9 @@ def cute_dsl_mla_decode( Whether the sequence length is variable. If True, the sequence length is variable. Otherwise,the sequence length is fixed for all the requests in the batch. - enable_pdl : bool - Whether to use PDL. + enable_pdl : Optional[bool], default=None + Whether to enable Programmatic Dependent Launch (PDL). + If None, auto-detects based on device capability. Returns ------- @@ -457,6 +460,8 @@ def cute_dsl_mla_decode( is_var_split_kv=is_var_split_kv, ) + enable_pdl = device_support_pdl(query.device) if enable_pdl is None else enable_pdl + # Get compiled kernel (cached after first compile) # Note: when is_workspace_size_zero is True, workspace_bytes is None and it will launch one kernel without workspace. # Otherwise, workspace_bytes is not None and it will launch two kernels. diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp16.py b/flashinfer/mla/cute_dsl/mla_decode_fp16.py index a570eb9ee9..dfbc98048a 100644 --- a/flashinfer/mla/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/mla/cute_dsl/mla_decode_fp16.py @@ -3580,6 +3580,7 @@ def run( iterations: int, skip_ref_check: bool, use_cold_l2: bool, + enable_pdl: bool = False, **kwargs, ): """Execute Multi-Head Latent Attention (MLA) on Blackwell architecture and validate results. @@ -3958,6 +3959,7 @@ def create_workspace( is_persistent, is_var_seq, is_var_split_kv, + enable_pdl, ) # Get current CUDA stream from PyTorch @@ -4433,6 +4435,12 @@ def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: help="Use cold L2 cache", ) + parser.add_argument( + "--enable_pdl", + action="store_true", + help="Enable PDL", + ) + args = parser.parse_args() run( @@ -4461,6 +4469,7 @@ def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: args.iterations, args.skip_ref_check, args.use_cold_l2, + args.enable_pdl, ) print("PASS") diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp8.py b/flashinfer/mla/cute_dsl/mla_decode_fp8.py index 8d26879020..242faf3db9 100644 --- a/flashinfer/mla/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/mla/cute_dsl/mla_decode_fp8.py @@ -3550,6 +3550,7 @@ def run( iterations: int, skip_ref_check: bool, use_cold_l2: bool, + enable_pdl: bool = False, **kwargs, ): """Execute Multi-Head Latent Attention (MLA) on Blackwell architecture and validate results. @@ -3929,6 +3930,7 @@ def create_workspace( is_persistent, is_var_seq, is_var_split_kv, + enable_pdl, ) # Get current CUDA stream from PyTorch @@ -4403,6 +4405,12 @@ def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: help="Use cold L2 cache", ) + parser.add_argument( + "--enable_pdl", + action="store_true", + help="Enable PDL", + ) + args = parser.parse_args() run( @@ -4431,6 +4439,7 @@ def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: args.iterations, args.skip_ref_check, args.use_cold_l2, + args.enable_pdl, ) print("PASS") From eaed46bb1bf5b711489f95d4e073d7b133e1dd6d Mon Sep 17 00:00:00 2001 From: yzh119 Date: Fri, 27 Mar 2026 23:47:18 -0700 Subject: [PATCH 4/4] remove __main__ blocks from mla_decode_fp8.py and mla_decode_fp16.py These standalone test/benchmark harnesses are superseded by the public API in mla_decode.py and proper tests in tests/. Also removes the now-unused argparse import. Co-Authored-By: Claude Opus 4.6 (1M context) --- flashinfer/mla/cute_dsl/mla_decode_fp16.py | 227 +-------------------- flashinfer/mla/cute_dsl/mla_decode_fp8.py | 226 +------------------- 2 files changed, 2 insertions(+), 451 deletions(-) diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp16.py b/flashinfer/mla/cute_dsl/mla_decode_fp16.py index dfbc98048a..df18a414fe 100644 --- a/flashinfer/mla/cute_dsl/mla_decode_fp16.py +++ b/flashinfer/mla/cute_dsl/mla_decode_fp16.py @@ -26,7 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import argparse + import math from typing import Type, Tuple, Optional from types import SimpleNamespace @@ -4248,228 +4248,3 @@ def generate_tensors(): ) return avg_time_us # Return execution time in microseconds - - -if __name__ == "__main__": - - def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: - try: - return tuple(int(x.strip()) for x in s.split(",")) - except ValueError: - raise argparse.ArgumentTypeError( # noqa: B904 - "Invalid format. Expected comma-separated integers." - ) - - def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: - ret = parse_comma_separated_ints(s) - if len(ret) != 2: - raise argparse.ArgumentTypeError( - "Invalid format. Expected 2 comma-separated integers." - ) - return (ret[0], ret[1]) # type: ignore[return-value] - - parser = argparse.ArgumentParser(description="Example of MLA on Blackwell.") - - parser.add_argument( - "--in_dtype", - type=cutlass.dtype, - default=cutlass.Float16, - help="Input data type", - ) - - parser.add_argument( - "--out_dtype", - type=cutlass.dtype, - default=cutlass.Float16, - help="Output data type", - ) - - parser.add_argument( - "--acc_dtype", - type=cutlass.dtype, - default=cutlass.Float32, - help="Accumulator data type", - ) - - parser.add_argument( - "--lse_dtype", - type=cutlass.dtype, - default=cutlass.Float32, - help="LSE data type", - ) - parser.add_argument( - "--mma_qk_tiler_mn", - type=parse_mma_tiler, - default=(128, 128), - help="MMA tile shape (H, K)", - ) - parser.add_argument( - "--mma_pv_tiler_mn", - type=parse_mma_tiler, - default=(128, 256), - help="MMA tile shape (H, D)", - ) - - parser.add_argument( - "--is_persistent", - action="store_true", - help="Is persistent", - ) - - parser.add_argument( - "--batch_size", - type=int, - default=1, - help="Batch size", - ) - - parser.add_argument( - "--seq_len_q", - type=int, - default=1, - help="Sequence length of Q", - ) - - parser.add_argument( - "--seq_len_k", - type=int, - default=128, - help="Sequence length of K/V", - ) - - parser.add_argument( - "--num_heads", - type=int, - default=128, - help="Number of heads of Q", - ) - - parser.add_argument( - "--latent_dim", - type=int, - default=512, - help="Latent dimension of Q/C", - ) - - parser.add_argument( - "--rope_dim", - type=int, - default=64, - help="Rope dimension of Q/C", - ) - - parser.add_argument( - "--is_var_seq", - action="store_true", - help="Use variable length of sequence length or not", - ) - - parser.add_argument( - "--is_var_split_kv", - action="store_true", - help="Use variable length of split kv or not", - ) - - parser.add_argument( - "--page_size", - type=int, - default=128, - help="Page size of page table", - ) - - parser.add_argument( - "--split_kv", - type=int, - default=-1, - help="Split KV setting", - ) - - parser.add_argument( - "--softmax_scale", - type=float, - default=0.0416, - help="Scaling factor to scale softmax", - ) - - parser.add_argument( - "--output_scale", - type=float, - default=1.0, - help="Scaling factor to scale output", - ) - - parser.add_argument( - "--skip_correction_threshold", - type=float, - default=0.0, - help="Skip correction threshold", - ) - - parser.add_argument( - "--tolerance", type=float, default=1e-02, help="Tolerance for validation" - ) - - parser.add_argument( - "--warmup_iterations", - type=int, - default=0, - help="Number of iterations for warmup", - ) - - parser.add_argument( - "--iterations", - type=int, - default=1, - help="Number of iterations after warmup", - ) - - parser.add_argument( - "--skip_ref_check", - action="store_true", - help="Skip reference check", - ) - - parser.add_argument( - "--use_cold_l2", - action="store_true", - help="Use cold L2 cache", - ) - - parser.add_argument( - "--enable_pdl", - action="store_true", - help="Enable PDL", - ) - - args = parser.parse_args() - - run( - args.batch_size, - args.seq_len_q, - args.seq_len_k, - args.num_heads, - args.latent_dim, - args.rope_dim, - args.in_dtype, - args.out_dtype, - args.acc_dtype, - args.lse_dtype, - args.mma_qk_tiler_mn, - args.mma_pv_tiler_mn, - args.split_kv, - args.is_persistent, - args.is_var_seq, - args.is_var_split_kv, - args.page_size, - args.softmax_scale, - args.output_scale, - args.skip_correction_threshold, - args.tolerance, - args.warmup_iterations, - args.iterations, - args.skip_ref_check, - args.use_cold_l2, - args.enable_pdl, - ) - - print("PASS") diff --git a/flashinfer/mla/cute_dsl/mla_decode_fp8.py b/flashinfer/mla/cute_dsl/mla_decode_fp8.py index 242faf3db9..638cc8a5b0 100644 --- a/flashinfer/mla/cute_dsl/mla_decode_fp8.py +++ b/flashinfer/mla/cute_dsl/mla_decode_fp8.py @@ -26,7 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import argparse + import math from typing import Type, Tuple, Optional from types import SimpleNamespace @@ -4219,227 +4219,3 @@ def generate_tensors(): ) return avg_time_us # Return execution time in microseconds - - -if __name__ == "__main__": - - def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: - try: - return tuple(int(x.strip()) for x in s.split(",")) - except ValueError: - raise argparse.ArgumentTypeError( # noqa: B904 - "Invalid format. Expected comma-separated integers." - ) - - def parse_mma_tiler(s: str) -> Tuple[int, int, Tuple[int, int]]: - ret = parse_comma_separated_ints(s) - if len(ret) != 2: - raise argparse.ArgumentTypeError( - "Invalid format. Expected 2 comma-separated integers." - ) - return (ret[0], ret[1]) # type: ignore[return-value] - - parser = argparse.ArgumentParser(description="Example of MLA on Blackwell.") - - parser.add_argument( - "--in_dtype", - type=cutlass.dtype, - default=cutlass.Float8E4M3FN, - help="Input data type", - ) - - parser.add_argument( - "--out_dtype", - type=cutlass.dtype, - default=cutlass.Float8E4M3FN, - help="Output data type", - ) - - parser.add_argument( - "--acc_dtype", - type=cutlass.dtype, - default=cutlass.Float32, - help="Accumulator data type", - ) - - parser.add_argument( - "--lse_dtype", - type=cutlass.dtype, - default=cutlass.Float32, - help="LSE data type", - ) - parser.add_argument( - "--mma_qk_tiler_mn", - type=parse_mma_tiler, - default=(128, 128), - help="MMA tile shape (H, K)", - ) - parser.add_argument( - "--mma_pv_tiler_mn", - type=parse_mma_tiler, - default=(128, 256), - help="MMA tile shape (H, D)", - ) - - parser.add_argument( - "--is_persistent", - action="store_true", - help="Is persistent", - ) - - parser.add_argument( - "--batch_size", - type=int, - default=1, - help="Batch size", - ) - - parser.add_argument( - "--seq_len_q", - type=int, - default=1, - help="Sequence length of Q", - ) - - parser.add_argument( - "--seq_len_k", - type=int, - default=128, - help="Sequence length of K/V", - ) - - parser.add_argument( - "--num_heads", - type=int, - default=128, - help="Number of heads of Q", - ) - - parser.add_argument( - "--latent_dim", - type=int, - default=512, - help="Latent dimension of Q/C", - ) - - parser.add_argument( - "--rope_dim", - type=int, - default=64, - help="Rope dimension of Q/C", - ) - - parser.add_argument( - "--is_var_seq", - action="store_true", - help="Use variable length of sequence length or not", - ) - - parser.add_argument( - "--is_var_split_kv", - action="store_true", - help="Use variable length of split kv or not", - ) - - parser.add_argument( - "--page_size", - type=int, - default=128, - help="Page size of page table", - ) - - parser.add_argument( - "--split_kv", - type=int, - default=-1, - help="Split KV setting", - ) - - parser.add_argument( - "--softmax_scale", - type=float, - default=0.0416, - help="Scaling factor to scale softmax", - ) - - parser.add_argument( - "--output_scale", - type=float, - default=1.0, - help="Scaling factor to scale output", - ) - parser.add_argument( - "--skip_correction_threshold", - type=float, - default=0.0, - help="Threshold to skip correction", - ) - - parser.add_argument( - "--tolerance", type=float, default=1e-02, help="Tolerance for validation" - ) - - parser.add_argument( - "--warmup_iterations", - type=int, - default=0, - help="Number of iterations for warmup", - ) - - parser.add_argument( - "--iterations", - type=int, - default=1, - help="Number of iterations after warmup", - ) - - parser.add_argument( - "--skip_ref_check", - action="store_true", - help="Skip reference check", - ) - - parser.add_argument( - "--use_cold_l2", - action="store_true", - help="Use cold L2 cache", - ) - - parser.add_argument( - "--enable_pdl", - action="store_true", - help="Enable PDL", - ) - - args = parser.parse_args() - - run( - args.batch_size, - args.seq_len_q, - args.seq_len_k, - args.num_heads, - args.latent_dim, - args.rope_dim, - args.in_dtype, - args.out_dtype, - args.acc_dtype, - args.lse_dtype, - args.mma_qk_tiler_mn, - args.mma_pv_tiler_mn, - args.split_kv, - args.is_persistent, - args.is_var_seq, - args.is_var_split_kv, - args.page_size, - args.softmax_scale, - args.output_scale, - args.skip_correction_threshold, - args.tolerance, - args.warmup_iterations, - args.iterations, - args.skip_ref_check, - args.use_cold_l2, - args.enable_pdl, - ) - - print("PASS")